Source code for selene_sdk.utils.utils

"""
This module provides utility functions that are not tied to specific
classes or concepts, but still perform specific and important roles
across many of the package's modules.

"""
from collections import OrderedDict
import logging
import sys

import numpy as np

from .multi_model_wrapper import MultiModelWrapper


def _is_lua_trained_model(model):
    if hasattr(model, 'from_lua'):
        return model.from_lua
    check_model = model
    if hasattr(model, 'model'):
        check_model = model.model
    elif type(model) == MultiModelWrapper and \
            hasattr(model, 'sub_models'):
        check_model = model.sub_models[0]
    setattr(model, "from_lua", False)
    setattr(check_model, "from_lua", False)
    for m in check_model.modules():
        if "Conv2d" in m.__class__.__name__:
            setattr(model, "from_lua", True)
            setattr(check_model, "from_lua", True)
    return model.from_lua


def get_indices_and_probabilities(interval_lengths, indices):
    """
    Computes sampling weights for a subset of intervals, proportional
    to each selected interval's length.

    Any index whose weight is not greater than 1e-10 is dropped and
    the weights are recomputed over the remaining indices; this
    repeats until every remaining index has a weight above that
    threshold.

    Parameters
    ----------
    interval_lengths : list(int)
        The lengths of all intervals that may be drawn from.
    indices : list(int)
        Positions in `interval_lengths` of the intervals of interest.

    Returns
    -------
    indices, weights : tuple(list(int), list(float))
        The interval indices to sample from and the corresponding
        weights of those intervals.

    """
    all_lengths = np.array(interval_lengths)
    candidates = indices
    while True:
        chosen_lengths = all_lengths[candidates]
        probabilities = chosen_lengths / float(np.sum(chosen_lengths))
        survivors = [
            candidates[position]
            for position, probability in enumerate(probabilities)
            if probability > 1e-10
        ]
        if len(survivors) == len(candidates):
            # Nothing was filtered out, so the weights are final.
            return candidates, probabilities.tolist()
        # Renormalise over the indices that passed the threshold.
        candidates = survivors
def load_model_from_state_dict(state_dict, model):
    """
    Loads model weights that were saved to a file previously by
    `torch.save`.

    This is a helper function to reconcile state dict keys where a
    model was saved with/without torch.nn.DataParallel and now must be
    loaded without/with torch.nn.DataParallel. When the number of keys
    matches, the saved values are paired with the model's own keys by
    position; otherwise the state dict is loaded non-strictly as-is.

    Parameters
    ----------
    state_dict : collections.OrderedDict
        The state of the model. If it contains a 'state_dict' entry,
        that entry is used as the state dict.
    model : torch.nn.Module
        The PyTorch model, a module composed of submodules.

    Returns
    -------
    torch.nn.Module \
        The model with weights loaded from the state dict.

    Raises
    ------
    ValueError
        If the state dict cannot be loaded into the model. The
        underlying error is attached as the exception's cause.

    """
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    model_keys = model.state_dict().keys()
    state_dict_keys = state_dict.keys()

    if len(model_keys) != len(state_dict_keys):
        # Key counts differ, so positional remapping is not possible;
        # fall back to a non-strict load of the state dict as given.
        try:
            model.load_state_dict(state_dict, strict=False)
        except Exception as e:
            raise ValueError("Loaded state dict does not match the model "
                "architecture specified - please check that you are "
                "using the correct architecture file and parameters.\n\n"
                "{0}".format(e)) from e
        return model

    # Pair each saved value with the model's key at the same position so
    # that prefix differences between the two key sets do not matter.
    new_state_dict = OrderedDict(
        (k1, state_dict[k2])
        for (k1, k2) in zip(model_keys, state_dict_keys))
    # The load call is the step that can actually fail, so this is
    # where the failure is converted into the documented ValueError.
    try:
        model.load_state_dict(new_state_dict)
    except Exception as e:
        raise ValueError(
            "Failed to load weights from the state dict into the model "
            "architecture. (If a module name has an additional prefix "
            "`model.` it is because the model is wrapped in "
            "`selene_sdk.utils.NonStrandSpecific`.) This error was "
            "raised because the underlying modules do not match those "
            "expected by the loaded model:\n"
            "{0}".format(e)) from e
    return model
def load_features_list(input_path):
    """
    Reads a file of feature names, one per line, and returns them as a
    list.

    Leading and trailing whitespace is stripped from every line.

    Parameters
    ----------
    input_path : str
        Path to the features file. Each feature in the input file must
        be on its own line.

    Returns
    -------
    list(str) \
        The features, in the same order in which they appear in the
        file (reading from top to bottom).

    Examples
    --------
    A file at "input_features.txt", for the feature names :math:`YFP`
    and :math:`YFG` might look like this:
    ::

        YFP
        YFG


    We can load these features from that file as follows:

    >>> load_features_list("input_features.txt")
    ["YFP", "YFG"]

    """
    with open(input_path, 'r') as feature_file:
        return [entry.strip() for entry in feature_file]
def initialize_logger(output_path, verbosity=2):
    """
    Initializes the logger for Selene.

    This function can only be called successfully once. If the
    "selene" logger already has handlers attached, the function
    returns without changing anything. Otherwise it sets the logger
    level and attaches a file handler and a stdout handler.

    Parameters
    ----------
    output_path : str
        The path to the output file where logs will be written.
    verbosity : int, {2, 1, 0}
        Default is 2. The level of logging verbosity to use.

            * 0 - Only warnings will be logged.
            * 1 - Information and warnings will be logged.
            * 2 - Debug messages, information, and warnings will all be\
                  logged.

    """
    selene_logger = logging.getLogger("selene")
    if selene_logger.handlers:
        # Handlers are already attached, so the logger was set up before.
        return

    # Any other verbosity value leaves the logger level untouched.
    verbosity_levels = (
        (0, logging.WARNING),
        (1, logging.INFO),
        (2, logging.DEBUG),
    )
    for code, level in verbosity_levels:
        if verbosity == code:
            selene_logger.setLevel(level)
            break

    # The file handler sets no level of its own.
    to_file = logging.FileHandler(output_path)
    to_file.setFormatter(logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s"))
    selene_logger.addHandler(to_file)

    # The stdout handler is limited to INFO and above.
    to_stdout = logging.StreamHandler(sys.stdout)
    to_stdout.setFormatter(logging.Formatter(
        "%(asctime)s - %(message)s"))
    to_stdout.setLevel(logging.INFO)
    selene_logger.addHandler(to_stdout)