# Source code for selene_sdk.predict.predict_handlers.logit_score_handler

"""
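Handles computing and outputting logit scores: the difference between the
log-odds of the alternative and reference predictions (i.e. the log fold
change in odds).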
"""
import numpy as np
from scipy.special import logit

from .handler import PredictionsHandler

class LogitScoreHandler(PredictionsHandler):
    """
    The logit score handler calculates and records the
    difference between logit(alt) and logit(ref) predictions
    (logit(alt) - logit(ref)).
    For reference, if some event occurs with probability :math:`p`,
    then the log-odds is the logit of :math:`p`, or

    .. math::
        \\mathrm{logit}(p) = \\log\\left(\\frac{p}{1 - p}\\right) =
        \\log(p) - \\log(1 - p)

    Parameters
    ----------
    features : list of str
        List of sequence-level features, in the same order that the
        model will return its predictions.
    columns_for_ids : list of str
        Columns in the file that help to identify the input sequence to
        which the features data corresponds.
    output_path_prefix : str
        Path to the file to which Selene will write the logit scores.
        The path may contain a filename prefix. Selene will append
        `logits` to the end of the prefix.
    output_format : {'tsv', 'hdf5'}
        Specify the desired output format. TSV can be specified if you
        would like the final file to be easily perused. However, saving
        to a TSV file is much slower than saving to an HDF5 file.
    output_size : int, optional
        The total number of rows in the output. Must be specified when
        the output_format is hdf5.
    write_mem_limit : int, optional
        Default is 1500. Specify the amount of memory you can allocate to
        storing model predictions/scores for this particular handler, in MB.
        Handler will write to file whenever this memory limit is reached.
    write_labels : bool, optional
        Default is True. If you initialize multiple write handlers for the
        same set of inputs with output format hdf5, set `write_labels` to
        False on all handlers except 1 so that only 1 handler writes the
        row labels to an output file.

    Attributes
    ----------
    needs_base_pred : bool
        Whether the handler needs the base (reference) prediction as input
        to compute the final output. Always True for this handler.

    """

    # Replacement for predictions that are exactly 0 (logit(0) is -inf).
    _MIN_PROB = 1e-24
    # Replacement for predictions >= 1 (logit(1) is +inf).
    _MAX_PROB = 0.999999

    def __init__(self,
                 features,
                 columns_for_ids,
                 output_path_prefix,
                 output_format,
                 output_size=None,
                 write_mem_limit=1500,
                 write_labels=True):
        """
        Constructs a new LogitScoreHandler object.
        """
        super().__init__(
            features,
            columns_for_ids,
            output_path_prefix,
            output_format,
            output_size=output_size,
            write_mem_limit=write_mem_limit,
            write_labels=write_labels)

        self.needs_base_pred = True
        # Per-batch score arrays and their identifiers, accumulated until
        # the memory limit triggers a write.
        self._results = []
        self._samples = []

        self._features = features
        self._columns_for_ids = columns_for_ids
        self._output_path_prefix = output_path_prefix
        self._output_format = output_format
        self._write_mem_limit = write_mem_limit
        self._write_labels = write_labels

        self._create_write_handler("logits")

    @classmethod
    def _bounded_probabilities(cls, predictions):
        """
        Return a copy of `predictions` that yields finite values under `logit`.

        Entries equal to 0 are replaced with `_MIN_PROB` and entries >= 1
        with `_MAX_PROB`. All other entries are left unchanged. The input
        is never modified.

        Parameters
        ----------
        predictions : arraylike
            Predicted probabilities.

        Returns
        -------
        numpy.ndarray
            A new floating-point array with the replacements applied.

        """
        probs = np.array(predictions, copy=True)
        if not np.issubdtype(probs.dtype, np.floating):
            # Integer/bool arrays cannot hold the fractional replacements.
            probs = probs.astype(np.float64)
        probs[probs == 0] = cls._MIN_PROB
        # NOTE(review): a value in (_MAX_PROB, 1) is left as-is and so gets a
        # larger logit than a value of exactly 1; kept for compatibility with
        # previously produced scores.
        probs[probs >= 1] = cls._MAX_PROB
        return probs

    def handle_batch_predictions(self,
                                 batch_predictions,
                                 batch_ids,
                                 baseline_predictions):
        """
        Handles the model predictions for a batch of sequences.

        Computes ``logit(batch_predictions) - logit(baseline_predictions)``
        and stores it together with `batch_ids`. The stored scores are
        written to file once the memory limit is reached. The input arrays
        are not modified; zero and >= 1 values are replaced in copies
        before taking the logit.

        Parameters
        ----------
        batch_predictions : arraylike
            The predictions for a batch of sequences. This should have
            dimensions of :math:`B \\times N` (where :math:`B` is the
            size of the mini-batch and :math:`N` is the number of
            features).
        batch_ids : list(arraylike)
            Batch of sequence identifiers. Each element is arraylike
            because it may contain more than one column (written to
            file) that together make up a unique identifier for a
            sequence.
        baseline_predictions : arraylike
            The baseline prediction(s) used to compute the logit scores.
            This must either be a vector of :math:`N` values, or a
            matrix of shape :math:`B \\times N` (where :math:`B` is
            the size of the mini-batch, and :math:`N` is the number of
            features).

        """
        # Work on copies: callers may pass the same arrays to other
        # handlers, which must see the unmodified predictions.
        alt_probs = self._bounded_probabilities(batch_predictions)
        ref_probs = self._bounded_probabilities(baseline_predictions)

        logits = logit(alt_probs) - logit(ref_probs)
        self._results.append(logits)
        self._samples.append(batch_ids)
        if self._reached_mem_limit():
            self.write_to_file()

    def write_to_file(self):
        """
        Write the stored scores to file.

        Delegates to the parent class implementation.

        """
        super().write_to_file()