Source code for xnmt.models.base

from typing import Optional, Sequence, Union

import dynet as dy

from xnmt import batchers, input_readers, sent
from xnmt.persistence import Serializable, serializable_init

class TrainableModel(object):
    """
    A template class for a basic trainable model, implementing a loss function.

    This class only defines the interface; subclasses must override
    :meth:`calc_nll`.
    """

    def calc_nll(self, *args, **kwargs) -> dy.Expression:
        """Calculate loss based on input-output pairs.

        Losses are accumulated only across unmasked timesteps in each batch element.

        Arguments are to be defined by subclasses.

        Returns:
          A (possibly batched) expression representing the loss.

        Raises:
          NotImplementedError: always, in this template class.
        """
        # Fail loudly instead of implicitly returning None, which would
        # contradict the declared return type and only surface later as an
        # obscure error. Mirrors GeneratorModel.generate().
        raise NotImplementedError("must be implemented by subclasses")
class UnconditionedModel(TrainableModel):
    """
    A template class for trainable model that computes target losses without conditioning on other inputs.

    Args:
      trg_reader: target reader
    """

    def __init__(self, trg_reader: input_readers.InputReader) -> None:
        self.trg_reader = trg_reader

    def calc_nll(self, trg: Union[batchers.Batch, sent.Sentence]) -> dy.Expression:
        """Calculate loss based on target inputs.

        Losses are accumulated only across unmasked timesteps in each batch element.

        Args:
          trg: The target, a sentence or a batch of sentences.

        Returns:
          A (possibly batched) expression representing the loss.

        Raises:
          NotImplementedError: always, in this template class.
        """
        # Fail loudly instead of implicitly returning None, which would
        # contradict the declared return type. Mirrors GeneratorModel.generate().
        raise NotImplementedError("must be implemented by subclasses")
class ConditionedModel(TrainableModel):
    """
    A template class for a trainable model that computes target losses conditioned on a source input.

    Args:
      src_reader: source reader
      trg_reader: target reader
    """

    def __init__(self, src_reader: input_readers.InputReader, trg_reader: input_readers.InputReader) -> None:
        self.src_reader = src_reader
        self.trg_reader = trg_reader

    def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
            -> dy.Expression:
        """Calculate loss based on input-output pairs.

        Losses are accumulated only across unmasked timesteps in each batch element.

        Args:
          src: The source, a sentence or a batch of sentences.
          trg: The target, a sentence or a batch of sentences.

        Returns:
          A (possibly batched) expression representing the loss.

        Raises:
          NotImplementedError: always, in this template class.
        """
        # Fail loudly instead of implicitly returning None, which would
        # contradict the declared return type. Mirrors GeneratorModel.generate().
        raise NotImplementedError("must be implemented by subclasses")
class GeneratorModel(object):
    """
    Base template for models that can run inference and produce some kind of output.

    Args:
      src_reader: reader for the source-side input
      trg_reader: optional reader for the target-side input, needed in some cases such as n-best scoring
    """

    def __init__(
        self,
        src_reader: input_readers.InputReader,
        trg_reader: Optional[input_readers.InputReader] = None,
    ) -> None:
        # Keep both readers on the instance; the target reader may be None.
        self.src_reader, self.trg_reader = src_reader, trg_reader

    def generate(self, src: batchers.Batch, *args, **kwargs) -> Sequence[sent.ReadableSentence]:
        """
        Produce outputs for a batch of inputs.

        Args:
          src: batch of source-side inputs
          *args: further positional arguments, to be specified by subclasses
          **kwargs: further keyword arguments, to be specified by subclasses

        Returns:
          output objects

        Raises:
          NotImplementedError: always, in this template class.
        """
        message = "must be implemented by subclasses"
        raise NotImplementedError(message)
class CascadeGenerator(GeneratorModel, Serializable):
    """
    A cascade that chains several generator models.

    Calling ``generate()`` on the cascade itself is not supported. Instead, its sub-generators
    should be accessed and used to generate outputs one by one.

    The cascade takes its source reader from the first sub-generator and its target reader
    from the last one.

    Args:
      generators: list of generators
    """

    yaml_tag = '!CascadeGenerator'

    @serializable_init
    def __init__(self, generators: Sequence[GeneratorModel]) -> None:
        first_generator = generators[0]
        last_generator = generators[-1]
        super().__init__(
            src_reader=first_generator.src_reader,
            trg_reader=last_generator.trg_reader,
        )
        self.generators = generators

    def generate(self, *args, **kwargs) -> Sequence[sent.ReadableSentence]:
        """
        Always fails: outputs must be generated through the sub-generators.

        Raises:
          ValueError: always.
        """
        message = "cannot call CascadeGenerator.generate() directly; access the sub-generators instead."
        raise ValueError(message)