# Source code for selene_sdk.samplers.intervals_sampler

"""
This module provides the `IntervalsSampler` class and supporting
methods.
"""
from collections import namedtuple
import logging
import random

import numpy as np

from .online_sampler import OnlineSampler
from ..utils import get_indices_and_probabilities

logger = logging.getLogger(__name__)


SampleIndices = namedtuple("SampleIndices", ["indices", "weights"])
"""
A tuple pairing a set of sample indices with the sampling weight
assigned to each index when drawing from them at random.

Parameters
----------
indices : list(int)
    The numeric index of each sample.
weights : list(float)
    The amount of weight assigned to each sample.

Attributes
----------
indices : list(int)
    The numeric index of each sample.
weights : list(float)
    The amount of weight assigned to each sample.

"""


# @TODO: Extend this class to work with stranded data.
class IntervalsSampler(OnlineSampler):
    """
    Draws samples from pre-specified windows in the reference sequence.

    Parameters
    ----------
    reference_sequence : selene_sdk.sequences.Sequence
        A reference sequence from which to create examples.
    target_path : str
        Path to tabix-indexed, compressed BED file (`*.bed.gz`) of
        genomic coordinates mapped to the genomic features we want to
        predict.
    features : list(str)
        List of distinct features that we aim to predict.
    intervals_path : str
        The path to the file that contains the intervals to sample
        from. In this file, each interval should occur on a separate
        line.
    sample_negative : bool, optional
        Default is `False`. This tells the sampler whether negative
        examples (i.e. with no positive labels) should be drawn when
        generating samples. If `True`, both negative and positive
        samples will be drawn. If `False`, only samples with at least
        one positive label will be drawn.
    seed : int, optional
        Default is 436. Sets the random seed for sampling.
    validation_holdout : list(str) or float, optional
        Default is `['chr6', 'chr7']`. Holdout can be regional or
        proportional. If regional, expects a list (e.g. `['X', 'Y']`).
        Regions must match those specified in the first column of the
        tabix-indexed BED file. If proportional, specify a percentage
        between (0.0, 1.0). Typically 0.10 or 0.20.
    test_holdout : list(str) or float, optional
        Default is `['chr8', 'chr9']`. See documentation for
        `validation_holdout` for additional information.
    sequence_length : int, optional
        Default is 1000. Model is trained on sequences of
        `sequence_length` where genomic features are annotated to the
        center regions of these sequences.
    center_bin_to_predict : int, optional
        Default is 200. Query the tabix-indexed file for a region of
        length `center_bin_to_predict`.
    feature_thresholds : float [0.0, 1.0] or None, optional
        Default is 0.5. The `feature_threshold` to pass to the
        `GenomicFeatures` object.
    mode : {'train', 'validate', 'test'}
        Default is `'train'`. The mode to run the sampler in.
    save_datasets : list of str
        Default is `["test"]`. The list of modes for which we should
        save the sampled data to file.
    output_dir : str or None, optional
        Default is None. The path to the directory where we should
        save sampled examples for a mode. If `save_datasets` is a
        non-empty list, `output_dir` must be specified. If the path in
        `output_dir` does not exist it will be created automatically.

    Attributes
    ----------
    reference_sequence : selene_sdk.sequences.Sequence
        The reference sequence that examples are created from.
    target : selene_sdk.targets.Target
        The `selene_sdk.targets.Target` object holding the features
        that we would like to predict.
    sample_from_intervals : list(tuple(str, int, int))
        A list of coordinates that specify the intervals we can draw
        samples from.
    interval_lengths : list(int)
        A list of the lengths of the intervals that we can draw samples
        from. The probability that we will draw a sample from an
        interval is a function of that interval's length and the length
        of all other intervals.
    sample_negative : bool
        Whether negative examples (i.e. with no positive label) should
        be drawn when generating samples. If `True`, both negative and
        positive samples will be drawn. If `False`, only samples with
        at least one positive label will be drawn.
    validation_holdout : list(str) or float
        The samples to hold out for validating model performance. These
        can be "regional" or "proportional". If regional, this is a
        list of region names (e.g. `['chrX', 'chrY']`). These regions
        must match those specified in the first column of the
        tabix-indexed BED file. If proportional, this is the fraction
        of total samples that will be held out.
    test_holdout : list(str) or float
        The samples to hold out for testing model performance. See the
        documentation for `validation_holdout` for more details.
    sequence_length : int
        The length of the sequences to train the model on.
    bin_radius : int
        From the center of the sequence, the radius in which to detect
        a feature annotation in order to include it as a sample's
        label.
    surrounding_sequence_radius : int
        The length of sequence falling outside of the feature detection
        bin (i.e. `bin_radius`) center, but still within the
        `sequence_length`.
    modes : list(str)
        The list of modes that the sampler can be run in.
    mode : str
        The current mode that the sampler is running in. Must be one of
        the modes listed in `modes`.

    """

    def __init__(self,
                 reference_sequence,
                 target_path,
                 features,
                 intervals_path,
                 sample_negative=False,
                 seed=436,
                 validation_holdout=['chr6', 'chr7'],
                 test_holdout=['chr8', 'chr9'],
                 sequence_length=1000,
                 center_bin_to_predict=200,
                 feature_thresholds=0.5,
                 mode="train",
                 save_datasets=["test"],
                 output_dir=None):
        """
        Constructs a new `IntervalsSampler` object.
        """
        super(IntervalsSampler, self).__init__(
            reference_sequence,
            target_path,
            features,
            seed=seed,
            validation_holdout=validation_holdout,
            test_holdout=test_holdout,
            sequence_length=sequence_length,
            center_bin_to_predict=center_bin_to_predict,
            feature_thresholds=feature_thresholds,
            mode=mode,
            save_datasets=save_datasets,
            output_dir=output_dir)

        # Per-mode sampling state: which interval indices a mode may draw
        # from (`_sample_from_mode`) and a cache of pre-drawn random
        # interval indices (`_randcache`) so we do not call
        # `np.random.choice` once per sample.
        self._sample_from_mode = {}
        self._randcache = {}
        for mode in self.modes:
            self._sample_from_mode[mode] = None
            self._randcache[mode] = {"cache_indices": None, "sample_next": 0}

        self.sample_from_intervals = []
        self.interval_lengths = []

        # `_holdout_type` is set by `OnlineSampler` based on whether the
        # holdouts were given as region-name lists or as proportions.
        if self._holdout_type == "chromosome":
            self._partition_dataset_chromosome(intervals_path)
        else:
            self._partition_dataset_proportion(intervals_path)

        for mode in self.modes:
            self._update_randcache(mode=mode)

        self.sample_negative = sample_negative

    def _partition_dataset_proportion(self, intervals_path):
        """
        When holdout sets are created by randomly sampling a proportion
        of the data, this method is used to divide the data into
        train/test/validate subsets.

        Parameters
        ----------
        intervals_path : str
            The path to the file that contains the intervals to sample
            from. In this file, each interval should occur on a
            separate line.

        """
        with open(intervals_path, 'r') as file_handle:
            for line in file_handle:
                cols = line.strip().split('\t')
                chrom = cols[0]
                start = int(cols[1])
                end = int(cols[2])
                self.sample_from_intervals.append((chrom, start, end))
                self.interval_lengths.append(end - start)
        n_intervals = len(self.sample_from_intervals)

        # all indices in the intervals list are shuffled
        select_indices = list(range(n_intervals))
        np.random.shuffle(select_indices)

        # the first section of indices is used as the validation set
        n_indices_validate = int(n_intervals * self.validation_holdout)
        val_indices, val_weights = get_indices_and_probabilities(
            self.interval_lengths, select_indices[:n_indices_validate])
        self._sample_from_mode["validate"] = SampleIndices(
            val_indices, val_weights)

        if self.test_holdout:
            # if applicable, the second section of indices is used as
            # the test set
            n_indices_test = int(n_intervals * self.test_holdout)
            test_indices_end = n_indices_test + n_indices_validate
            test_indices, test_weights = get_indices_and_probabilities(
                self.interval_lengths,
                select_indices[n_indices_validate:test_indices_end])
            self._sample_from_mode["test"] = SampleIndices(
                test_indices, test_weights)

            # remaining indices are for the training set
            tr_indices, tr_weights = get_indices_and_probabilities(
                self.interval_lengths, select_indices[test_indices_end:])
            self._sample_from_mode["train"] = SampleIndices(
                tr_indices, tr_weights)
        else:
            # remaining indices are for the training set
            tr_indices, tr_weights = get_indices_and_probabilities(
                self.interval_lengths, select_indices[n_indices_validate:])
            self._sample_from_mode["train"] = SampleIndices(
                tr_indices, tr_weights)

    def _partition_dataset_chromosome(self, intervals_path):
        """
        When holdout sets are created by selecting all samples from a
        specified region (e.g. a chromosome) this method is used to
        divide the data into train/test/validate subsets.

        Parameters
        ----------
        intervals_path : str
            The path to the file that contains the intervals to sample
            from. In this file, each interval should occur on a
            separate line.

        """
        for mode in self.modes:
            self._sample_from_mode[mode] = SampleIndices([], [])
        with open(intervals_path, 'r') as file_handle:
            for index, line in enumerate(file_handle):
                cols = line.strip().split('\t')
                chrom = cols[0]
                start = int(cols[1])
                end = int(cols[2])
                # route each interval to the holdout whose region list
                # contains its chromosome; everything else is training
                if chrom in self.validation_holdout:
                    self._sample_from_mode["validate"].indices.append(
                        index)
                elif self.test_holdout and chrom in self.test_holdout:
                    self._sample_from_mode["test"].indices.append(
                        index)
                else:
                    self._sample_from_mode["train"].indices.append(
                        index)

                self.sample_from_intervals.append((chrom, start, end))
                self.interval_lengths.append(end - start)

        for mode in self.modes:
            sample_indices = self._sample_from_mode[mode].indices
            indices, weights = get_indices_and_probabilities(
                self.interval_lengths, sample_indices)
            self._sample_from_mode[mode] = \
                self._sample_from_mode[mode]._replace(
                    indices=indices, weights=weights)

    def _retrieve(self, chrom, position):
        """
        Retrieves samples around a position in the `reference_sequence`.

        Parameters
        ----------
        chrom : str
            The name of the region (e.g. "chrX", "YFP")
        position : int
            The position in the query region that we will search around
            for samples.

        Returns
        -------
        retrieved_seq, retrieved_targets : \
        tuple(numpy.ndarray, list(numpy.ndarray))
            A tuple containing the numeric representation of the
            sequence centered at the query position, as well as a list
            of samples within this region that met the filtering
            standards. Returns `None` when the draw should be retried
            (no features found and `sample_negative` is off, sequence
            could not be retrieved, or too many ambiguous bases).

        """
        bin_start = position - self._start_radius
        bin_end = position + self._end_radius
        retrieved_targets = self.target.get_feature_data(
            chrom, bin_start, bin_end)
        if not self.sample_negative and np.sum(retrieved_targets) == 0:
            logger.info("No features found in region surrounding "
                        "region \"{0}\" position {1}. Sampling again.".format(
                            chrom, position))
            return None

        window_start = bin_start - self.surrounding_sequence_radius
        window_end = bin_end + self.surrounding_sequence_radius
        strand = self.STRAND_SIDES[random.randint(0, 1)]
        retrieved_seq = \
            self.reference_sequence.get_encoding_from_coords(
                chrom, window_start, window_end, strand)
        if retrieved_seq.shape[0] == 0:
            logger.info("Full sequence centered at region \"{0}\" position "
                        "{1} could not be retrieved. Sampling again.".format(
                            chrom, position))
            return None
        elif np.sum(retrieved_seq) / float(retrieved_seq.shape[0]) < 0.60:
            # each unambiguous base contributes 1 to the one-hot sum, so
            # a mean below 0.60 means >30% (NOTE: message says 30%, the
            # threshold here is 40%) of positions are ambiguous ('N')
            logger.info("Over 30% of the bases in the sequence centered "
                        "at region \"{0}\" position {1} are ambiguous ('N'). "
                        "Sampling again.".format(chrom, position))
            return None

        if self.mode in self._save_datasets:
            feature_indices = ';'.join(
                [str(f) for f in np.nonzero(retrieved_targets)[0]])
            self._save_datasets[self.mode].append(
                [chrom,
                 window_start,
                 window_end,
                 strand,
                 feature_indices])
            # periodically flush saved examples so memory stays bounded
            if len(self._save_datasets[self.mode]) > 200000:
                self.save_dataset_to_file(self.mode)
        return (retrieved_seq, retrieved_targets)

    def _update_randcache(self, mode=None):
        """
        Updates the cache of indices of intervals. This allows us to
        randomly sample from our data without having to use a
        fixed-point approach or keeping all labels in memory.

        Parameters
        ----------
        mode : str or None, optional
            Default is `None`. The mode that these samples should be
            used for. See `selene_sdk.samplers.IntervalsSampler.modes`
            for more information.

        """
        if not mode:
            mode = self.mode
        self._randcache[mode]["cache_indices"] = np.random.choice(
            self._sample_from_mode[mode].indices,
            size=len(self._sample_from_mode[mode].indices),
            replace=True,
            p=self._sample_from_mode[mode].weights)
        self._randcache[mode]["sample_next"] = 0

    def sample(self, batch_size=1):
        """
        Randomly draws a mini-batch of examples and their corresponding
        labels.

        Parameters
        ----------
        batch_size : int, optional
            Default is 1. The number of examples to include in the
            mini-batch.

        Returns
        -------
        sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
            A tuple containing the numeric representation of the
            sequence examples and their corresponding labels. The
            shape of `sequences` will be
            :math:`B \\times L \\times N`, where :math:`B` is
            `batch_size`, :math:`L` is the sequence length, and
            :math:`N` is the size of the sequence type's alphabet.
            The shape of `targets` will be :math:`B \\times F`,
            where :math:`F` is the number of features.

        """
        sequences = np.zeros((batch_size, self.sequence_length, 4))
        targets = np.zeros((batch_size, self.n_features))
        n_samples_drawn = 0
        while n_samples_drawn < batch_size:
            sample_index = self._randcache[self.mode]["sample_next"]
            if sample_index == len(self._sample_from_mode[self.mode].indices):
                # exhausted the pre-drawn cache; refill and start over
                self._update_randcache()
                sample_index = 0

            rand_interval_index = \
                self._randcache[self.mode]["cache_indices"][sample_index]
            self._randcache[self.mode]["sample_next"] += 1

            interval_info = self.sample_from_intervals[rand_interval_index]
            interval_length = self.interval_lengths[rand_interval_index]

            chrom = interval_info[0]
            # pick a uniformly random position within the interval
            position = int(
                interval_info[1] + random.uniform(0, 1) * interval_length)

            retrieve_output = self._retrieve(chrom, position)
            if not retrieve_output:
                # retrieval was rejected (no features / bad sequence);
                # draw another position
                continue
            seq, seq_targets = retrieve_output
            sequences[n_samples_drawn, :, :] = seq
            targets[n_samples_drawn, :] = seq_targets
            n_samples_drawn += 1
        return (sequences, targets)