Source code for xnmt.models.classifiers

import numbers
from typing import Optional, Sequence, Union

import dynet as dy
import numpy as np

from xnmt import batchers, event_trigger, inferences, input_readers, sent
from xnmt.modelparts import transforms
from xnmt.modelparts import scorers
from xnmt.modelparts import embedders
from xnmt.transducers import recurrent
from xnmt.transducers import base as transducers
from xnmt.models import base as models
from xnmt.persistence import serializable_init, Serializable, bare

class SequenceClassifier(models.ConditionedModel, models.GeneratorModel, Serializable):
  """
  A sequence classifier.

  Runs embeddings through an encoder, then feeds the encoder's final state to a transform and scoring layer.

  Args:
    src_reader: A reader for the source side.
    trg_reader: A reader for the target side.
    src_embedder: A word embedder for the input language
    encoder: An encoder to generate encoded inputs
    inference: how to perform inference
    transform: A transform performed before the scoring function
    scorer: A scoring function over the multiple choices
  """

  yaml_tag = '!SequenceClassifier'

  @serializable_init
  def __init__(self,
               src_reader: input_readers.InputReader,
               trg_reader: input_readers.InputReader,
               src_embedder: embedders.Embedder = bare(embedders.SimpleWordEmbedder),
               encoder: transducers.SeqTransducer = bare(recurrent.BiLSTMSeqTransducer),
               inference=bare(inferences.IndependentOutputInference),
               transform: transforms.Transform = bare(transforms.NonLinear),
               scorer: scorers.Scorer = bare(scorers.Softmax)) -> None:
    super().__init__(src_reader=src_reader, trg_reader=trg_reader)
    self.src_embedder = src_embedder
    self.encoder = encoder
    self.transform = transform
    self.scorer = scorer
    self.inference = inference
  def shared_params(self):
    return [{".src_embedder.emb_dim", ".encoder.input_dim"},
            {".encoder.hidden_dim", ".transform.input_dim"},
            {".transform.output_dim", ".scorer.input_dim"}]
  def _encode_src(self, src):
    event_trigger.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    # run the encoder; only its final state is used below, not the full output sequence
    self.encoder.transduce(embeddings)
    h = self.encoder.get_final_states()[-1].main_expr()
    return self.transform.transform(h)
  def calc_nll(self, src: Union[batchers.Batch, sent.Sentence], trg: Union[batchers.Batch, sent.Sentence]) \
          -> dy.Expression:
    h = self._encode_src(src)
    # the target side holds scalar class labels; collect them into a batch if needed
    ids = trg.value if not batchers.is_batched(trg) else batchers.ListBatch([trg_i.value for trg_i in trg])
    loss_expr = self.scorer.calc_loss(h, ids)
    return loss_expr
  def generate(self, src: Union[batchers.Batch, sent.Sentence], normalize_scores: bool = False):
    if not batchers.is_batched(src):
      src = batchers.mark_as_batch([src])
    h = self._encode_src(src)
    # best_k with k=1 yields the single highest-scoring class per batch element
    best_words, best_scores = self.scorer.best_k(h, k=1, normalize_scores=normalize_scores)
    assert best_words.shape == (1, src.batch_size())
    assert best_scores.shape == (1, src.batch_size())
    outputs = []
    for batch_i in range(src.batch_size()):
      if src.batch_size() > 1:
        word = best_words[0, batch_i]
        score = best_scores[0, batch_i]
      else:
        word = best_words[0]
        score = best_scores[0]
      outputs.append(sent.ScalarSentence(value=word, score=score))
    return outputs
  def get_nobp_state(self, state):
    output_state = state.as_vector()
    return output_state
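
The classification head described in the docstring above amounts to: take the encoder's final state, apply a non-linear transform, and score the result against the candidate classes. The following standalone numpy sketch is an illustration only, not xnmt/DyNet code; all dimensions and weights are made up, and the tanh transform and linear-plus-softmax scorer are assumed analogues of the default transforms.NonLinear and scorers.Softmax components.

import numpy as np

rng = np.random.default_rng(0)
hidden_dim, transform_dim, num_classes = 8, 6, 3

h = rng.standard_normal(hidden_dim)                      # stand-in for the final encoder state
W_t = rng.standard_normal((transform_dim, hidden_dim))   # transform weights (illustrative)
b_t = np.zeros(transform_dim)
W_s = rng.standard_normal((num_classes, transform_dim))  # scorer weights (illustrative)
b_s = np.zeros(num_classes)

t = np.tanh(W_t @ h + b_t)               # analogue of transforms.NonLinear
logits = W_s @ t + b_s                   # analogue of the scorer's class scores
probs = np.exp(logits - logits.max())    # softmax over the candidate classes
probs /= probs.sum()
predicted_class = int(np.argmax(probs))  # what generate() would wrap in a ScalarSentence
print(predicted_class, probs)

In the real model the same pipeline runs on DyNet expressions, calc_nll computes the cross-entropy of the gold label under these class scores, and generate returns the top-scoring class per input as a ScalarSentence.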