# Source code for selene_sdk.predict.predict_handlers.write_predictions_handler

"""
Handles outputting the model predictions
"""
from .handler import PredictionsHandler

class WritePredictionsHandler(PredictionsHandler):
    """
    Buffers batches of model predictions and writes them to file.

    Every batch passed to `handle_batch_predictions` is kept in memory.
    The buffered batches are written out whenever the handler reports
    that its memory limit has been reached, or when `write_to_file` is
    called directly.

    Parameters
    ----------
    features : list(str)
        List of sequence-level features, in the same order that the
        model will return its predictions.
    columns_for_ids : list(str)
        Columns in the file that help to identify the input sequence to
        which the features data corresponds.
    output_path_prefix : str
        Path to the file to which Selene will write the predictions.
        The path may contain a filename prefix. Selene will append
        `predictions` to the end of the prefix.
    output_format : {'tsv', 'hdf5'}
        Specify the desired output format. TSV can be specified if you
        would like the final file to be easily perused. However, saving
        to a TSV file is much slower than saving to an HDF5 file.
    output_size : int, optional
        The total number of rows in the output. Must be specified when
        the `output_format` is hdf5.
    write_mem_limit : int, optional
        Default is 1500. Specify the amount of memory you can allocate to
        storing model predictions/scores for this particular handler, in
        MB. Handler will write to file whenever this memory limit is
        reached.
    write_labels : bool, optional
        Default is True. If you initialize multiple write handlers for
        the same set of inputs with output format `hdf5`, set
        `write_labels` to False on all handlers except 1 so that only 1
        handler writes the row labels to an output file.

    Attributes
    ----------
    needs_base_pred : bool
        Whether the handler needs the base (reference) prediction as
        input to compute the final output. Always `False` for this
        handler.

    """

    def __init__(self,
                 features,
                 columns_for_ids,
                 output_path_prefix,
                 output_format,
                 output_size=None,
                 write_mem_limit=1500,
                 write_labels=True):
        """
        Constructs a new `WritePredictionsHandler` object.
        """
        super().__init__(
            features,
            columns_for_ids,
            output_path_prefix,
            output_format,
            output_size=output_size,
            write_mem_limit=write_mem_limit,
            write_labels=write_labels)

        # This handler writes raw predictions, so no reference
        # prediction is required.
        self.needs_base_pred = False

        # In-memory buffers, filled batch by batch.
        self._results = []
        self._samples = []

        # NOTE(review): the base-class constructor is not visible from
        # this file and may already store these values; the explicit
        # assignments are kept so behavior does not depend on it.
        self._features = features
        self._columns_for_ids = columns_for_ids
        self._output_path_prefix = output_path_prefix
        self._output_format = output_format
        self._write_mem_limit = write_mem_limit
        self._write_labels = write_labels

        # Helper is not defined in this class (presumably inherited);
        # it is asked to set up the writer for the "predictions" output.
        self._create_write_handler("predictions")

    def handle_batch_predictions(self,
                                 batch_predictions,
                                 batch_ids):
        """
        Stores the predictions for a batch of sequences.

        Parameters
        ----------
        batch_predictions : arraylike
            The predictions for a batch of sequences. This should have
            dimensions of :math:`B \\times N` (where :math:`B` is the
            size of the mini-batch and :math:`N` is the number of
            features).
        batch_ids : list(arraylike)
            Batch of sequence identifiers. Each element is `arraylike`
            because it may contain more than one column (written to
            file) that together make up a unique identifier for a
            sequence.
        """
        self._results.append(batch_predictions)
        self._samples.append(batch_ids)

        # Keep buffering until the memory limit is hit, then flush.
        if not self._reached_mem_limit():
            return
        self.write_to_file()

    def write_to_file(self):
        """
        Writes the stored predictions to a file.

        The actual writing is delegated to the base-class
        implementation.
        """
        super().write_to_file()