Source code for selene_sdk.train_model

"""
This module provides the `TrainModel` class and supporting methods.
"""
import logging
import math
import os
import shutil
from time import strftime
from time import time

import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score

from .utils import initialize_logger
from .utils import load_model_from_state_dict
from .utils import PerformanceMetrics

logger = logging.getLogger("selene")


def _metrics_logger(name, out_filepath):
    logger = logging.getLogger("{0}".format(name))
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(message)s")
    file_handle = logging.FileHandler(
        os.path.join(out_filepath, "{0}.txt".format(name)))
    file_handle.setFormatter(formatter)
    logger.addHandler(file_handle)
    return logger

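# A minimal sketch of how `_metrics_logger` is used below. The logger name
# incorporates the module's `__name__`, so the file written is, e.g.,
# `selene_sdk.train_model.train.txt` in the given output directory
# (the directory path here is illustrative):
#
#     train_logger = _metrics_logger("selene_sdk.train_model.train", "./outputs")
#     train_logger.info("loss")  # header row; one loss value is appended per report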

class TrainModel(object):
    """
    This class ties together the various objects and methods needed to
    train and validate a model.

    TrainModel saves a checkpoint model (overwriting it after
    `save_checkpoint_every_n_steps`) as well as a best-performing model
    (overwriting it after `report_stats_every_n_steps` if the latest
    validation performance is better than the previous best-performing
    model) to `output_dir`.

    TrainModel also outputs two files that can be used to monitor training
    as Selene runs: `selene_sdk.train_model.train.txt` (training loss) and
    `selene_sdk.train_model.validation.txt` (validation loss & average
    ROC AUC). The columns in these files can be used to quickly visualize
    the training history (e.g. you can use `matplotlib`,
    `plt.plot(auc_list)`) and see, for example, whether the model is still
    improving or whether there are signs of overfitting.

    Parameters
    ----------
    model : torch.nn.Module
        The model to train.
    data_sampler : selene_sdk.samplers.Sampler
        The example generator.
    loss_criterion : torch.nn._Loss
        The loss function to optimize.
    optimizer_class : torch.optim.Optimizer
        The optimizer to minimize loss with.
    optimizer_kwargs : dict
        The dictionary of keyword arguments to pass to the optimizer's
        constructor.
    batch_size : int
        Specify the batch size to process examples. Should be a power of 2.
    max_steps : int
        The maximum number of mini-batches to iterate over.
    report_stats_every_n_steps : int
        The frequency with which to report summary statistics. You can
        set this value to be equivalent to a training epoch
        (`n_steps * batch_size` being the total number of samples seen
        by the model so far). Selene evaluates the model on the validation
        dataset every `report_stats_every_n_steps` and, if the model obtains
        the best performance so far (based on the user-specified loss
        function), Selene saves the model state to a file called
        `best_model.pth.tar` in `output_dir`.
    output_dir : str
        The output directory to save model checkpoints and logs in.
    save_checkpoint_every_n_steps : int or None, optional
        Default is 1000. If None, set to the same value as
        `report_stats_every_n_steps`.
    save_new_checkpoints_after_n_steps : int or None, optional
        Default is None. The number of steps after which Selene will
        continually save new checkpoint model weights files
        (`checkpoint-<TIMESTAMP>.pth.tar`) every
        `save_checkpoint_every_n_steps`. Before this point,
        the file `checkpoint.pth.tar` is overwritten every
        `save_checkpoint_every_n_steps` to limit the memory requirements.
    report_gt_feature_n_positives : int, optional
        Default is 10. The minimum number of positive examples of a feature
        that must be present in the targets for validation/test metrics to
        be computed for that feature.
    n_validation_samples : int or None, optional
        Default is `None`. Specify the number of validation samples in the
        validation set. If `n_validation_samples` is `None` and the data
        sampler used is the `selene_sdk.samplers.IntervalsSampler` or
        `selene_sdk.samplers.RandomSampler`, we will retrieve 32000
        validation samples. If `None` and using
        `selene_sdk.samplers.MultiSampler`, we will use all
        available validation samples from the appropriate data file.
    n_test_samples : int or None, optional
        Default is `None`. Specify the number of test samples in the test
        set. If `n_test_samples` is `None` and

            - the sampler you specified has no test partition, you should
              not specify `evaluate` as one of the operations in the `ops`
              list. That is, Selene will not automatically evaluate your
              trained model on a test dataset, because the sampler you are
              using does not have any test data.
            - the sampler you use is of type
              `selene_sdk.samplers.OnlineSampler` (and the test partition
              exists), we will retrieve 640000 test samples.
            - the sampler you use is of type
              `selene_sdk.samplers.MultiSampler` (and the test partition
              exists), we will use all the test samples available in the
              appropriate data file.

    cpu_n_threads : int, optional
        Default is 1. Sets the number of OpenMP threads used for
        parallelizing CPU operations.
    use_cuda : bool, optional
        Default is `False`. Specify whether a CUDA-enabled GPU is available
        for torch to use during training.
    data_parallel : bool, optional
        Default is `False`. Specify whether multiple GPUs are available
        for torch to use during training.
    logging_verbosity : {0, 1, 2}, optional
        Default is 2. Set the logging verbosity level.

            * 0 - Only warnings will be logged.
            * 1 - Information and warnings will be logged.
            * 2 - Debug messages, information, and warnings will all be
              logged.

    checkpoint_resume : str or None, optional
        Default is `None`. If `checkpoint_resume` is not None, it should be
        the path to a model file generated by `torch.save` that can now be
        read using `torch.load`.
    metrics : dict, optional
        Default is `dict(roc_auc=roc_auc_score,
        average_precision=average_precision_score)`. A dictionary mapping
        metric names to functions with the signature
        `metric(targets, predictions)`, used for validation and testing.
    use_scheduler : bool, optional
        Default is `True`. If `True`, a learning rate scheduler is used to
        reduce the learning rate on plateau. The PyTorch ReduceLROnPlateau
        scheduler with patience=16 and factor=0.8 is used. Different
        scheduler parameters can be specified with `scheduler_kwargs`.
    deterministic : bool, optional
        Default is `False`. If `True`, will set
        `torch.backends.cudnn.deterministic` to True and
        `torch.backends.cudnn.benchmark = False`. In the Selene CLI, if
        `random_seed` is set in the configuration YAML, Selene automatically
        passes in `deterministic=True` to the TrainModel class.
    scheduler_kwargs : dict, optional
        Default is patience=16, verbose=True, and factor=0.8. Set the
        parameters for the PyTorch ReduceLROnPlateau scheduler.
    stopping_criteria : list or None, optional
        Default is `None`. If `stopping_criteria` is not None, it should be
        a list specifying how to use early stopping. The first value should
        be a str corresponding to one of `metrics`. The second value should
        be an int indicating the patience. If the specified metric does not
        improve within the given patience (usually corresponding to the
        number of epochs), training stops early.

    Attributes
    ----------
    model : torch.nn.Module
        The model to train.
    sampler : selene_sdk.samplers.Sampler
        The example generator.
    criterion : torch.nn._Loss
        The loss function to optimize.
    optimizer : torch.optim.Optimizer
        The optimizer to minimize loss with.
    batch_size : int
        The size of the mini-batch to use during training.
    max_steps : int
        The maximum number of mini-batches to iterate over.
    nth_step_report_stats : int
        The frequency with which to report summary statistics.
    nth_step_save_checkpoint : int
        The frequency with which to save a model checkpoint.
    use_cuda : bool
        If `True`, use a CUDA-enabled GPU. If `False`, use the CPU.
    data_parallel : bool
        Whether to use multiple GPUs or not.
    output_dir : str
        The directory to save model checkpoints and logs.

    """
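    # A minimal construction sketch. `MyModel` and `sampler` are hypothetical
    # stand-ins for a user-defined architecture and a configured
    # `selene_sdk.samplers.Sampler`; the hyperparameter values are
    # illustrative only:
    #
    #     trainer = TrainModel(
    #         model=MyModel(),
    #         data_sampler=sampler,
    #         loss_criterion=nn.BCELoss(),
    #         optimizer_class=torch.optim.SGD,
    #         optimizer_kwargs={"lr": 0.01, "momentum": 0.9},
    #         batch_size=64,
    #         max_steps=8000,
    #         report_stats_every_n_steps=1000,
    #         output_dir="./outputs")
    #     trainer.train_and_validate()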
""" def __init__(self, model, data_sampler, loss_criterion, optimizer_class, optimizer_kwargs, batch_size, max_steps, report_stats_every_n_steps, output_dir, save_checkpoint_every_n_steps=1000, save_new_checkpoints_after_n_steps=None, report_gt_feature_n_positives=10, n_validation_samples=None, n_test_samples=None, cpu_n_threads=1, use_cuda=False, data_parallel=False, logging_verbosity=2, checkpoint_resume=None, metrics=dict(roc_auc=roc_auc_score, average_precision=average_precision_score), use_scheduler=True, deterministic=False, scheduler_kwargs=dict(patience=16, verbose=True, factor=0.8), stopping_criteria=None): """ Constructs a new `TrainModel` object. """ self.model = model self.sampler = data_sampler self.criterion = loss_criterion self.optimizer = optimizer_class( self.model.parameters(), **optimizer_kwargs) self.batch_size = batch_size self.max_steps = max_steps self.nth_step_report_stats = report_stats_every_n_steps self.nth_step_save_checkpoint = None if not save_checkpoint_every_n_steps: self.nth_step_save_checkpoint = report_stats_every_n_steps else: self.nth_step_save_checkpoint = save_checkpoint_every_n_steps self._save_new_checkpoints = save_new_checkpoints_after_n_steps os.makedirs(output_dir, exist_ok=True) self.output_dir = output_dir initialize_logger( os.path.join(self.output_dir, "{0}.log".format(__name__)), verbosity=logging_verbosity) if deterministic: logger.info("Setting deterministic = True for reproducibility.") torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False logger.info("Training parameters set: batch size {0}, " "number of steps per 'epoch': {1}, " "maximum number of steps: {2}".format( self.batch_size, self.nth_step_report_stats, self.max_steps)) torch.set_num_threads(cpu_n_threads) self.use_cuda = use_cuda self.data_parallel = data_parallel if self.data_parallel: self.model = nn.DataParallel(model) logger.debug("Wrapped model in DataParallel") if self.use_cuda: self.model.cuda() self.criterion.cuda() logger.debug("Set modules to use CUDA") self._report_gt_feature_n_positives = report_gt_feature_n_positives self._metrics = metrics self._n_validation_samples = n_validation_samples self._n_test_samples = n_test_samples self._use_scheduler = use_scheduler self._init_train(scheduler_kwargs) self._init_validate() if "test" in self.sampler.modes: self._init_test() if checkpoint_resume is not None: self._load_checkpoint(checkpoint_resume) if type(stopping_criteria) is list and len(stopping_criteria) == 2: stopping_metric, stopping_patience = stopping_criteria self._early_stopping = True if stopping_metric in self._metrics: self._stopping_metric = stopping_metric self._stopping_patience = stopping_patience self._stopping_reached = False else: logger.warning("Did not recognize stopping metric. Not performing early stopping.") self._early_stopping = False else: self._early_stopping = False def _load_checkpoint(self, checkpoint_resume): checkpoint = torch.load( checkpoint_resume, map_location=lambda storage, location: storage) if "state_dict" not in checkpoint: raise ValueError( ("'state_dict' not found in file {0} " "loaded with method `torch.load`. 
Selene does not support " "continued training of models that were not originally " "trained using Selene.").format(checkpoint_resume)) self.model = load_model_from_state_dict( checkpoint["state_dict"], self.model) self._start_step = checkpoint["step"] if self._start_step >= self.max_steps: self.max_steps += self._start_step self._min_loss = checkpoint["min_loss"] self.optimizer.load_state_dict( checkpoint["optimizer"]) if self.use_cuda: for state in self.optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda() logger.info( ("Resuming from checkpoint: step {0}, min loss {1}").format( self._start_step, self._min_loss)) def _init_train(self, scheduler_kwargs): self._start_step = 0 self._train_logger = _metrics_logger( "{0}.train".format(__name__), self.output_dir) self._train_logger.info("loss") if self._use_scheduler: self.scheduler = ReduceLROnPlateau( self.optimizer, 'min', **scheduler_kwargs) self._time_per_step = [] self._train_loss = [] def _init_validate(self): self._min_loss = float("inf") # TODO: Should this be set when it is used later? Would need to if we want to train model 2x in one run. self._create_validation_set(n_samples=self._n_validation_samples) self._validation_metrics = PerformanceMetrics( self.sampler.get_feature_from_index, report_gt_feature_n_positives=self._report_gt_feature_n_positives, metrics=self._metrics) self._validation_logger = _metrics_logger( "{0}.validation".format(__name__), self.output_dir) self._validation_logger.info("\t".join(["loss"] + sorted([x for x in self._validation_metrics.metrics.keys()]))) def _init_test(self): self._test_data = None self._n_test_samples = self._n_test_samples self._test_metrics = PerformanceMetrics( self.sampler.get_feature_from_index, report_gt_feature_n_positives=self._report_gt_feature_n_positives, metrics=self._metrics) def _create_validation_set(self, n_samples=None): """ Generates the set of validation examples. Parameters ---------- n_samples : int or None, optional Default is `None`. The size of the validation set. If `None`, will use all validation examples in the sampler. """ logger.info("Creating validation dataset.") t_i = time() self._validation_data, self._all_validation_targets = \ self.sampler.get_validation_set( self.batch_size, n_samples=n_samples) t_f = time() logger.info(("{0} s to load {1} validation examples ({2} validation " "batches) to evaluate after each training step.").format( t_f - t_i, len(self._validation_data) * self.batch_size, len(self._validation_data)))
    def create_test_set(self):
        """
        Loads the set of test samples.

        We do not create the test set in the `TrainModel` object until
        this method is called, so that we avoid having to load it into
        memory until the model has been trained and is ready to be
        evaluated.

        """
        logger.info("Creating test dataset.")
        t_i = time()
        self._test_data, self._all_test_targets = \
            self.sampler.get_test_set(
                self.batch_size, n_samples=self._n_test_samples)
        t_f = time()
        logger.info(("{0} s to load {1} test examples ({2} test batches) "
                     "to evaluate after all training steps.").format(
                         t_f - t_i,
                         len(self._test_data) * self.batch_size,
                         len(self._test_data)))
        np.savez_compressed(
            os.path.join(self.output_dir, "test_targets.npz"),
            data=self._all_test_targets)
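    # The compressed targets saved above can be reloaded later for offline
    # analysis. A sketch (the path below assumes the `output_dir` passed at
    # construction was "./outputs"):
    #
    #     test_targets = np.load("outputs/test_targets.npz")["data"]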
    def _get_batch(self):
        """
        Fetches a mini-batch of examples.

        Returns
        -------
        tuple(numpy.ndarray, numpy.ndarray)
            A tuple containing the examples and targets.

        """
        t_i_sampling = time()
        batch_sequences, batch_targets = self.sampler.sample(
            batch_size=self.batch_size)
        t_f_sampling = time()
        logger.debug(
            ("[BATCH] Time to sample {0} examples: {1} s.").format(
                self.batch_size,
                t_f_sampling - t_i_sampling))
        return (batch_sequences, batch_targets)

    def _checkpoint(self):
        checkpoint_dict = {
            "step": self.step,
            "arch": self.model.__class__.__name__,
            "state_dict": self.model.state_dict(),
            "min_loss": self._min_loss,
            "optimizer": self.optimizer.state_dict()
        }
        # Once past `save_new_checkpoints_after_n_steps`, keep distinct
        # timestamped checkpoint files; before that point, overwrite
        # `checkpoint.pth.tar`. (The original condition compared in the
        # opposite direction, contradicting the class docstring.)
        if self._save_new_checkpoints is not None and \
                self.step >= self._save_new_checkpoints:
            checkpoint_filename = "checkpoint-{0}".format(
                strftime("%m%d%H%M%S"))
            self._save_checkpoint(
                checkpoint_dict, False, filename=checkpoint_filename)
            logger.debug("Saving checkpoint `{0}.pth.tar`".format(
                checkpoint_filename))
        else:
            self._save_checkpoint(
                checkpoint_dict, False)
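    # A checkpoint written by `_checkpoint` can be used to resume training by
    # passing its path as `checkpoint_resume`. A sketch (the path is
    # illustrative; other constructor arguments are elided):
    #
    #     trainer = TrainModel(
    #         ..., checkpoint_resume="outputs/checkpoint.pth.tar")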
    def train_and_validate(self):
        """
        Trains the model and measures validation performance.

        """
        for step in range(self._start_step, self.max_steps):
            self.step = step
            self.train()

            if step % self.nth_step_save_checkpoint == 0:
                self._checkpoint()

            if self.step and self.step % self.nth_step_report_stats == 0:
                self.validate()
                if self._early_stopping and self._stopping_reached:
                    logger.debug("Patience ran out. Stopping early.")
                    break

        self.sampler.save_dataset_to_file("train", close_filehandle=True)
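    # Worked example of the step arithmetic above (illustrative values): with
    # `max_steps=8000`, `report_stats_every_n_steps=1000`, and
    # `save_checkpoint_every_n_steps=1000`, the loop checkpoints at steps
    # 0, 1000, ..., 7000, validates at steps 1000, 2000, ..., 7000, and stops
    # after step 7999.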
    def train(self):
        """
        Trains the model on a batch of data.

        Returns
        -------
        float
            The training loss.

        """
        t_i = time()
        self.model.train()
        self.sampler.set_mode("train")

        inputs, targets = self._get_batch()
        inputs = torch.Tensor(inputs)
        targets = torch.Tensor(targets)

        if self.use_cuda:
            inputs = inputs.cuda()
            targets = targets.cuda()

        predictions = self.model(inputs.transpose(1, 2))
        loss = self.criterion(predictions, targets)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self._train_loss.append(loss.item())
        t_f = time()
        self._time_per_step.append(t_f - t_i)

        if self.step and self.step % self.nth_step_report_stats == 0:
            logger.info(("[STEP {0}] average number "
                         "of steps per second: {1:.1f}").format(
                self.step, 1. / np.average(self._time_per_step)))
            self._train_logger.info(np.average(self._train_loss))
            logger.info("training loss: {0}".format(
                np.average(self._train_loss)))
            self._time_per_step = []
            self._train_loss = []

        # Return the batch loss so the method matches its docstring. (The
        # original code computed the loss but did not return it.)
        return loss.item()
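    # Note on shapes: samplers yield one-hot encoded sequences, typically of
    # shape (batch, sequence_length, alphabet_size), and
    # `inputs.transpose(1, 2)` above rearranges them into the
    # (batch, channels, length) layout that `torch.nn.Conv1d`-based models
    # expect. A standalone sketch with illustrative dimensions:
    #
    #     batch = torch.zeros(64, 1000, 4)      # 64 sequences of length 1000
    #     out = model(batch.transpose(1, 2))    # model sees (64, 4, 1000)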
    def _evaluate_on_data(self, data_in_batches):
        """
        Makes predictions for some labeled input data.

        Parameters
        ----------
        data_in_batches : list(tuple(numpy.ndarray, numpy.ndarray))
            A list of tuples of the data, where the first element is
            the example, and the second element is the label.

        Returns
        -------
        tuple(float, numpy.ndarray)
            Returns the average loss and the array of all predictions
            (batches stacked with `np.vstack`).

        """
        self.model.eval()

        batch_losses = []
        all_predictions = []

        for (inputs, targets) in data_in_batches:
            inputs = torch.Tensor(inputs)
            targets = torch.Tensor(targets)

            if self.use_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()

            with torch.no_grad():
                predictions = self.model(inputs.transpose(1, 2))
                loss = self.criterion(predictions, targets)

                all_predictions.append(
                    predictions.data.cpu().numpy())

                batch_losses.append(loss.item())
        all_predictions = np.vstack(all_predictions)
        return np.average(batch_losses), all_predictions
    def validate(self):
        """
        Measures model validation performance.

        Returns
        -------
        dict
            A dictionary, where keys are the names of the loss metrics,
            and the values are the average value for that metric over
            the validation set.

        """
        validation_loss, all_predictions = self._evaluate_on_data(
            self._validation_data)
        valid_scores = self._validation_metrics.update(
            all_predictions, self._all_validation_targets)
        for name, score in valid_scores.items():
            logger.info("validation {0}: {1}".format(name, score))

        valid_scores["loss"] = validation_loss

        to_log = [str(validation_loss)]
        for k in sorted(self._validation_metrics.metrics.keys()):
            if k in valid_scores and valid_scores[k]:
                to_log.append(str(valid_scores[k]))
            else:
                to_log.append("NA")
        self._validation_logger.info("\t".join(to_log))

        # scheduler update
        if self._use_scheduler:
            self.scheduler.step(
                math.ceil(validation_loss * 1000.0) / 1000.0)

        # save best_model
        if validation_loss < self._min_loss:
            self._min_loss = validation_loss
            self._save_checkpoint({
                "step": self.step,
                "arch": self.model.__class__.__name__,
                "state_dict": self.model.state_dict(),
                "min_loss": self._min_loss,
                "optimizer": self.optimizer.state_dict()}, True)
            logger.debug("Updating `best_model.pth.tar`")
        logger.info("validation loss: {0}".format(validation_loss))

        # check for early stopping
        if self._early_stopping:
            stopping_metric = self._validation_metrics.metrics[
                self._stopping_metric].data
            index = np.argmax(stopping_metric)
            if self._stopping_patience - \
                    (len(stopping_metric) - index - 1) <= 0:
                self._stopping_reached = True

        # Return the scores so the method matches its docstring. (The
        # original code computed `valid_scores` but did not return it.)
        return valid_scores
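    # Early-stopping arithmetic (illustrative): if `stopping_criteria` was
    # `["roc_auc", 10]`, then `index` above is the position of the best
    # validation ROC AUC so far. Once that best score is 10 or more
    # validations old, `patience - (num_scores - index - 1) <= 0` holds,
    # `_stopping_reached` is set, and `train_and_validate` breaks out of
    # its loop at the next check.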
    def evaluate(self):
        """
        Measures the model test performance.

        Returns
        -------
        tuple(dict, dict)
            The first dictionary maps metric names (and "loss") to the
            average value of that metric over the test set; the second
            maps feature names to their per-feature scores.

        """
        if self._test_data is None:
            self.create_test_set()
        average_loss, all_predictions = self._evaluate_on_data(
            self._test_data)
        average_scores = self._test_metrics.update(all_predictions,
                                                   self._all_test_targets)

        np.savez_compressed(
            os.path.join(self.output_dir, "test_predictions.npz"),
            data=all_predictions)

        for name, score in average_scores.items():
            logger.info("test {0}: {1}".format(name, score))

        test_performance = os.path.join(
            self.output_dir, "test_performance.txt")
        feature_scores_dict = self._test_metrics.write_feature_scores_to_file(
            test_performance)

        average_scores["loss"] = average_loss

        self._test_metrics.visualize(
            all_predictions, self._all_test_targets, self.output_dir)

        return (average_scores, feature_scores_dict)
    def _save_checkpoint(self, state, is_best,
                         filename="checkpoint"):
        """
        Saves snapshot of the model state to file. Will save a checkpoint
        with name `<filename>.pth.tar` and, if this is the model's best
        performance so far, will save the state to a `best_model.pth.tar`
        file as well.

        Models are saved in the state dictionary format. This is a more
        stable format compared to saving the whole model (which is another
        option supported by PyTorch). Note that we do save a number of
        additional, Selene-specific parameters in the dictionary
        and that the actual `model.state_dict()` is stored in the
        `state_dict` key of the dictionary loaded by `torch.load`.

        See: https://pytorch.org/docs/stable/notes/serialization.html for
        more information about how models are saved in PyTorch.

        Parameters
        ----------
        state : dict
            Information about the state of the model. Note that this is
            not `model.state_dict()`, but rather, a dictionary containing
            keys that can be used for continued training in Selene
            _in addition_ to a key `state_dict` that contains
            `model.state_dict()`.
        is_best : bool
            Is this the model's best performance so far?
        filename : str, optional
            Default is "checkpoint". Specify the checkpoint filename. Will
            append a file extension to the end of the `filename`
            (e.g. `checkpoint.pth.tar`).

        Returns
        -------
        None

        """
        logger.debug("[TRAIN] {0}: Saving model state to file.".format(
            state["step"]))
        cp_filepath = os.path.join(
            self.output_dir, filename)
        torch.save(state, "{0}.pth.tar".format(cp_filepath))
        if is_best:
            best_filepath = os.path.join(self.output_dir,
                                         "best_model")
            shutil.copyfile("{0}.pth.tar".format(cp_filepath),
                            "{0}.pth.tar".format(best_filepath))
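
# A saved `best_model.pth.tar` can be restored with plain PyTorch. A sketch
# (`MyModel` and the path are hypothetical; the weights live under the
# checkpoint's "state_dict" key, as described in `_save_checkpoint` above):
#
#     checkpoint = torch.load("outputs/best_model.pth.tar",
#                             map_location=lambda storage, location: storage)
#     model = MyModel()
#     model.load_state_dict(checkpoint["state_dict"])
#     model.eval()
#
# If the model was trained with `data_parallel=True`, the saved keys may
# carry a `module.` prefix; `selene_sdk.utils.load_model_from_state_dict`
# (imported at the top of this module) handles that case.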