# Source code for selene_sdk.samplers.intervals_sampler

"""
This module provides the IntervalsSampler class and supporting
methods.
"""
from collections import namedtuple
import logging
import random

import numpy as np

from .online_sampler import OnlineSampler
from ..utils import get_indices_and_probabilities

logger = logging.getLogger(__name__)

SampleIndices = namedtuple("SampleIndices", ("indices", "weights"))
"""
A tuple containing the indices for some samples, and a weight to
allot to each index when randomly drawing from them.

Parameters
----------
indices : list(int)
    The numeric index of each sample.
weights : list(float)
    The amount of weight assigned to each sample.

Attributes
----------
indices : list(int)
    The numeric index of each sample.
weights : list(float)
    The amount of weight assigned to each sample.

"""

# @TODO: Extend this class to work with stranded data.
class IntervalsSampler(OnlineSampler):
    """
    Draws samples from pre-specified windows in the reference sequence.

    Parameters
    ----------
    reference_sequence : selene_sdk.sequences.Sequence
        A reference sequence from which to create examples.
    target_path : str
        Path to tabix-indexed, compressed BED file (`*.bed.gz`) of genomic
        coordinates mapped to the genomic features we want to predict.
    features : list(str)
        List of distinct features that we aim to predict.
    intervals_path : str
        The path to the file that contains the intervals to sample from.
        In this file, each interval should occur on a separate line, as
        tab-separated `chrom`, `start`, `end` columns. Blank lines are
        ignored.
    sample_negative : bool, optional
        Default is `False`. This tells the sampler whether negative
        examples (i.e. with no positive labels) should be drawn when
        generating samples. If `True`, both negative and positive
        samples will be drawn. If `False`, only samples with at least
        one positive label will be drawn.
    seed : int, optional
        Default is 436. Sets the random seed for sampling.
    validation_holdout : list(str) or float, optional
        Default is `['chr6', 'chr7']`. Holdout can be regional or
        proportional. If regional, expects a list (e.g. `['X', 'Y']`).
        Regions must match those specified in the first column of the
        tabix-indexed BED file. If proportional, specify a percentage
        between (0.0, 1.0). Typically 0.10 or 0.20.
    test_holdout : list(str) or float, optional
        Default is `['chr8', 'chr9']`. See documentation for
        `validation_holdout` for additional information.
    sequence_length : int, optional
        Default is 1000. Model is trained on sequences of
        `sequence_length` where genomic features are annotated to the
        center regions of these sequences.
    center_bin_to_predict : int, optional
        Default is 200. Query the tabix-indexed file for a region of
        length `center_bin_to_predict`.
    feature_thresholds : float [0.0, 1.0] or None, optional
        Default is 0.5. The `feature_threshold` to pass to the
        `GenomicFeatures` object.
    mode : {'train', 'validate', 'test'}
        Default is `'train'`. The mode to run the sampler in.
    save_datasets : list(str), optional
        Default is `["test"]`. The list of modes for which we should
        save the sampled data to file.
    output_dir : str or None, optional
        Default is None. The path to the directory where we should
        save sampled examples for a mode. If `save_datasets` is
        a non-empty list, `output_dir` must be specified. If
        the path in `output_dir` does not exist it will be created
        automatically.

    Attributes
    ----------
    reference_sequence : selene_sdk.sequences.Sequence
        The reference sequence that examples are created from.
    target : selene_sdk.targets.Target
        The `selene_sdk.targets.Target` object holding the features that
        we would like to predict.
    sample_from_intervals : list(tuple(str, int, int))
        A list of coordinates that specify the intervals we can draw
        samples from.
    interval_lengths : list(int)
        A list of the lengths of the intervals that we can draw samples
        from. The probability that we will draw a sample from an
        interval is a function of that interval's length and the length
        of all other intervals.
    sample_negative : bool
        Whether negative examples (i.e. with no positive label) should
        be drawn when generating samples. If `True`, both negative and
        positive samples will be drawn. If `False`, only samples with at
        least one positive label will be drawn.
    validation_holdout : list(str) or float
        The samples to hold out for validating model performance. These
        can be "regional" or "proportional". If regional, this is a list
        of region names (e.g. `['chrX', 'chrY']`). These regions must
        match those specified in the first column of the tabix-indexed
        BED file. If proportional, this is the fraction of total samples
        that will be held out.
    test_holdout : list(str) or float
        The samples to hold out for testing model performance. See the
        documentation for `validation_holdout` for more details.
    sequence_length : int
        The length of the sequences to train the model on.
    bin_radius : int
        From the center of the sequence, the radius in which to detect
        a feature annotation in order to include it as a sample's label.
    surrounding_sequence_radius : int
        The length of sequence falling outside of the feature detection
        bin (i.e. `bin_radius`) center, but still within the
        `sequence_length`.
    modes : list(str)
        The list of modes that the sampler can be run in.
    mode : str
        The current mode that the sampler is running in. Must be one of
        the modes listed in `modes`.

    """

    # NOTE: the list-valued defaults below are kept for interface
    # compatibility. This class only forwards and reads them; it never
    # mutates them.
    def __init__(self,
                 reference_sequence,
                 target_path,
                 features,
                 intervals_path,
                 sample_negative=False,
                 seed=436,
                 validation_holdout=['chr6', 'chr7'],
                 test_holdout=['chr8', 'chr9'],
                 sequence_length=1000,
                 center_bin_to_predict=200,
                 feature_thresholds=0.5,
                 mode="train",
                 save_datasets=["test"],
                 output_dir=None):
        """
        Constructs a new `IntervalsSampler` object.
        """
        super(IntervalsSampler, self).__init__(
            reference_sequence,
            target_path,
            features,
            seed=seed,
            validation_holdout=validation_holdout,
            test_holdout=test_holdout,
            sequence_length=sequence_length,
            center_bin_to_predict=center_bin_to_predict,
            feature_thresholds=feature_thresholds,
            mode=mode,
            save_datasets=save_datasets,
            output_dir=output_dir)

        # Per-mode bookkeeping: which interval indices belong to the mode,
        # and a pre-drawn cache of random interval indices to consume.
        # The loop variable is deliberately not called `mode` so that it
        # does not shadow the constructor argument.
        self._sample_from_mode = {}
        self._randcache = {}
        for sampler_mode in self.modes:
            self._sample_from_mode[sampler_mode] = None
            self._randcache[sampler_mode] = {
                "cache_indices": None,
                "sample_next": 0}

        self.sample_from_intervals = []
        self.interval_lengths = []

        if self._holdout_type == "chromosome":
            self._partition_dataset_chromosome(intervals_path)
        else:
            self._partition_dataset_proportion(intervals_path)

        for sampler_mode in self.modes:
            self._update_randcache(mode=sampler_mode)

        self.sample_negative = sample_negative

    @staticmethod
    def _read_intervals(intervals_path):
        """
        Yields the intervals listed in a tab-separated intervals file.

        Parameters
        ----------
        intervals_path : str
            The path to the file that contains the intervals to sample
            from. Each non-blank line must hold at least three
            tab-separated columns: region name, start, and end.

        Yields
        ------
        tuple(str, int, int)
            The `(chrom, start, end)` coordinates of each interval, in
            file order. Blank lines are skipped.

        """
        with open(intervals_path, 'r') as file_handle:
            for line in file_handle:
                line = line.strip()
                if not line:
                    # e.g. a trailing empty line at the end of the file
                    continue
                cols = line.split('\t')
                yield cols[0], int(cols[1]), int(cols[2])

    def _partition_dataset_proportion(self, intervals_path):
        """
        When holdout sets are created by randomly sampling a proportion
        of the data, this method is used to divide the data into
        train/test/validate subsets.

        Parameters
        ----------
        intervals_path : str
            The path to the file that contains the intervals to sample
            from. In this file, each interval should occur on a separate
            line.

        """
        for chrom, start, end in self._read_intervals(intervals_path):
            self.sample_from_intervals.append((chrom, start, end))
            self.interval_lengths.append(end - start)
        n_intervals = len(self.sample_from_intervals)

        # all indices in the intervals list are shuffled
        select_indices = list(range(n_intervals))
        np.random.shuffle(select_indices)

        # the first section of indices is used as the validation set
        n_indices_validate = int(n_intervals * self.validation_holdout)
        val_indices, val_weights = get_indices_and_probabilities(
            self.interval_lengths, select_indices[:n_indices_validate])
        self._sample_from_mode["validate"] = SampleIndices(
            val_indices, val_weights)

        train_start = n_indices_validate
        if self.test_holdout:
            # if applicable, the second section of indices is used as the
            # test set
            n_indices_test = int(n_intervals * self.test_holdout)
            test_indices_end = n_indices_test + n_indices_validate
            test_indices, test_weights = get_indices_and_probabilities(
                self.interval_lengths,
                select_indices[n_indices_validate:test_indices_end])
            self._sample_from_mode["test"] = SampleIndices(
                test_indices, test_weights)
            train_start = test_indices_end

        # remaining indices are for the training set
        tr_indices, tr_weights = get_indices_and_probabilities(
            self.interval_lengths, select_indices[train_start:])
        self._sample_from_mode["train"] = SampleIndices(
            tr_indices, tr_weights)

    def _partition_dataset_chromosome(self, intervals_path):
        """
        When holdout sets are created by selecting all samples from a
        specified region (e.g. a chromosome) this method is used to
        divide the data into train/test/validate subsets.

        Parameters
        ----------
        intervals_path : str
            The path to the file that contains the intervals to sample
            from. In this file, each interval should occur on a separate
            line.

        """
        for sampler_mode in self.modes:
            self._sample_from_mode[sampler_mode] = SampleIndices([], [])

        for chrom, start, end in self._read_intervals(intervals_path):
            # Position this interval will occupy in `sample_from_intervals`.
            # Using the list length (not the file line number) keeps the
            # indices aligned even when blank lines were skipped.
            index = len(self.sample_from_intervals)
            if chrom in self.validation_holdout:
                self._sample_from_mode["validate"].indices.append(index)
            elif self.test_holdout and chrom in self.test_holdout:
                self._sample_from_mode["test"].indices.append(index)
            else:
                self._sample_from_mode["train"].indices.append(index)
            self.sample_from_intervals.append((chrom, start, end))
            self.interval_lengths.append(end - start)

        for sampler_mode in self.modes:
            sample_indices = self._sample_from_mode[sampler_mode].indices
            indices, weights = get_indices_and_probabilities(
                self.interval_lengths, sample_indices)
            self._sample_from_mode[sampler_mode] = \
                self._sample_from_mode[sampler_mode]._replace(
                    indices=indices, weights=weights)

    def _retrieve(self, chrom, position):
        """
        Retrieves samples around a position in the `reference_sequence`.

        Parameters
        ----------
        chrom : str
            The name of the region (e.g. "chrX", "YFP")
        position : int
            The position in the query region that we will search around
            for samples.

        Returns
        -------
        retrieved_seq, retrieved_targets : \
        tuple(numpy.ndarray, list(numpy.ndarray)) or None
            A tuple containing the numeric representation of the
            sequence centered at the query position, as well as a list
            of samples within this region that met the filtering
            standards. `None` is returned when the position is rejected:
            no positive label while `sample_negative` is `False`, an
            empty retrieved sequence, or a sequence whose encoding sums
            to less than 0.60 per position.

        """
        bin_start = position - self._start_radius
        bin_end = position + self._end_radius
        retrieved_targets = self.target.get_feature_data(
            chrom, bin_start, bin_end)
        if not self.sample_negative and np.sum(retrieved_targets) == 0:
            logger.info(
                "No features found in region surrounding "
                "region \"%s\" position %s. Sampling again.",
                chrom, position)
            return None

        window_start = bin_start - self.surrounding_sequence_radius
        window_end = bin_end + self.surrounding_sequence_radius
        strand = self.STRAND_SIDES[random.randint(0, 1)]
        retrieved_seq = \
            self.reference_sequence.get_encoding_from_coords(
                chrom, window_start, window_end, strand)

        if retrieved_seq.shape[0] == 0:
            logger.info(
                "Full sequence centered at region \"%s\" position "
                "%s could not be retrieved. Sampling again.",
                chrom, position)
            return None
        elif np.sum(retrieved_seq) / float(retrieved_seq.shape[0]) < 0.60:
            # The message reports 40% so that it agrees with the 0.60
            # cut-off above (it previously said 30%).
            # NOTE(review): this assumes an ambiguous base contributes 0
            # to the encoding sum - confirm against the Sequence class,
            # and confirm that 0.60 (not 0.70) is the intended cut-off.
            logger.info(
                "Over 40%% of the bases in the sequence centered "
                "at region \"%s\" position %s are ambiguous ('N'). "
                "Sampling again.",
                chrom, position)
            return None

        if self.mode in self._save_datasets:
            feature_indices = ';'.join(
                [str(f) for f in np.nonzero(retrieved_targets)[0]])
            self._save_datasets[self.mode].append(
                [chrom,
                 window_start,
                 window_end,
                 strand,
                 feature_indices])
            # Flush the buffered rows once more than 200000 have
            # accumulated for this mode.
            if len(self._save_datasets[self.mode]) > 200000:
                self.save_dataset_to_file(self.mode)
        return (retrieved_seq, retrieved_targets)

    def _update_randcache(self, mode=None):
        """
        Updates the cache of indices of intervals. This allows us
        to randomly sample from our data without having to use a
        fixed-point approach or keeping all labels in memory.

        Parameters
        ----------
        mode : str or None, optional
            Default is `None`. The mode that these samples should be
            used for. When falsy, the sampler's current `mode` is used.
            See `selene_sdk.samplers.IntervalsSampler.modes` for more
            information.

        """
        if not mode:
            mode = self.mode
        mode_samples = self._sample_from_mode[mode]
        # Draw (with replacement) as many interval indices as the mode
        # owns, weighted by the per-interval probabilities.
        self._randcache[mode]["cache_indices"] = np.random.choice(
            mode_samples.indices,
            size=len(mode_samples.indices),
            replace=True,
            p=mode_samples.weights)
        self._randcache[mode]["sample_next"] = 0

    def sample(self, batch_size=1):
        """
        Randomly draws a mini-batch of examples and their corresponding
        labels.

        Parameters
        ----------
        batch_size : int, optional
            Default is 1. The number of examples to include in the
            mini-batch.

        Returns
        -------
        sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
            A tuple containing the numeric representation of the
            sequence examples and their corresponding labels. The
            shape of `sequences` will be
            :math:`B \\times L \\times N`, where :math:`B` is
            `batch_size`, :math:`L` is the sequence length, and
            :math:`N` is the size of the sequence type's alphabet.
            The shape of `targets` will be :math:`B \\times F`,
            where :math:`F` is the number of features.

        """
        # NOTE(review): the alphabet size is hard-coded to 4 here -
        # confirm this holds for every reference sequence type in use.
        sequences = np.zeros((batch_size, self.sequence_length, 4))
        targets = np.zeros((batch_size, self.n_features))
        n_samples_drawn = 0
        while n_samples_drawn < batch_size:
            sample_index = self._randcache[self.mode]["sample_next"]
            if sample_index == len(self._sample_from_mode[self.mode].indices):
                # The cache for this mode is exhausted; redraw it.
                self._update_randcache()
                sample_index = 0

            rand_interval_index = \
                self._randcache[self.mode]["cache_indices"][sample_index]
            self._randcache[self.mode]["sample_next"] += 1

            interval_info = self.sample_from_intervals[rand_interval_index]
            interval_length = self.interval_lengths[rand_interval_index]

            chrom = interval_info[0]
            # Uniformly random position within the chosen interval.
            position = int(
                interval_info[1] + random.uniform(0, 1) * interval_length)

            retrieve_output = self._retrieve(chrom, position)
            if retrieve_output is None:
                # Position was rejected; draw another one.
                continue
            seq, seq_targets = retrieve_output
            sequences[n_samples_drawn, :, :] = seq
            targets[n_samples_drawn, :] = seq_targets
            n_samples_drawn += 1
        return (sequences, targets)