Source code for selene_sdk.utils.non_strand_specific_module

"""
This module provides the NonStrandSpecific class.
"""
import torch
from torch.nn.modules import Module

from . import _is_lua_trained_model


def _flip(x, dim):
    """
    Reverses the elements in a given dimension `dim` of the Tensor.

    source: https://github.com/pytorch/pytorch/issues/229
    """
    xsize = x.size()
    dim = x.dim() + dim if dim < 0 else dim
    x = x.contiguous()
    x = x.view(-1, *xsize[dim:])
    x = x.view(
        x.size(0), x.size(1), -1)[:, getattr(
            torch.arange(x.size(1)-1, -1, -1),
            ('cpu','cuda')[x.is_cuda])().long(), :]
    return x.view(xsize)


class NonStrandSpecific(Module):
    """
    A torch.nn.Module that wraps a user-specified model architecture if the
    architecture does not need to account for sequence strand-specificity.

    Parameters
    ----------
    model : torch.nn.Module
        The user-specified model architecture.
    mode : {'mean', 'max'}, optional
        Default is 'mean'. NonStrandSpecific will pass the input and the
        reverse-complement of the input into `model`. The mode specifies
        whether we should output the mean or max of the predictions as
        the non-strand specific prediction.

    Attributes
    ----------
    model : torch.nn.Module
        The user-specified model architecture.
    mode : {'mean', 'max'}
        How to handle outputting a non-strand specific prediction.
    from_lua : object
        The value returned by `_is_lua_trained_model(model)`; it is used
        as a truth value to select how the reversed input is built.

    Raises
    ------
    ValueError
        If `mode` is neither 'mean' nor 'max'.

    """

    def __init__(self, model, mode="mean"):
        super(NonStrandSpecific, self).__init__()

        self.model = model

        if mode not in ("mean", "max"):
            raise ValueError("Mode should be one of 'mean' or 'max' but was "
                             "{0}.".format(mode))
        self.mode = mode
        self.from_lua = _is_lua_trained_model(model)

    def forward(self, input):
        """
        Runs the wrapped model on `input` and on `input` reversed along
        dimensions 1 and 2, then combines the two predictions.

        Parameters
        ----------
        input : torch.Tensor
            The model input. NOTE(review): presumably one-hot sequences
            laid out as (batch, alphabet, length), ordered so that
            reversing dimension 1 yields the complement -- confirm
            against callers.

        Returns
        -------
        torch.Tensor
            The element-wise mean of the two predictions when `mode` is
            'mean', otherwise their element-wise maximum.

        """
        if self.from_lua:
            # The input has an extra dimension at index 2 here: squeeze it
            # so dimensions 1 and 2 can be reversed, then put it back.
            # NOTE(review): assumes that dimension has size 1 -- confirm.
            reverse_input = _flip(
                _flip(torch.squeeze(input, 2), 1), 2).unsqueeze_(2)
        else:
            reverse_input = _flip(_flip(input, 1), 2)

        # Call the module itself rather than `.forward` so that any hooks
        # registered on the wrapped model are run.
        output = self.model(input)
        output_from_rev = self.model(reverse_input)

        if self.mode == "mean":
            return (output + output_from_rev) / 2
        return torch.max(output, output_from_rev)