Source code for xnmt.models.sequence_labelers

from typing import Optional, Set, Sequence, Union
import numbers

import dynet as dy
import numpy as np

from xnmt import batchers, event_trigger, events, inferences, input_readers, reports, sent, vocabs
from xnmt.modelparts import embedders, scorers, transforms
from xnmt.models import base as models
from xnmt.transducers import recurrent, base as transducers
from xnmt.persistence import serializable_init, Serializable, bare

class SeqLabeler(models.ConditionedModel, models.GeneratorModel, Serializable, reports.Reportable):
  """
  A simple sequence labeler based on an encoder and an output softmax layer.

  Args:
    src_reader: A reader for the source side.
    trg_reader: A reader for the target side.
    src_embedder: A word embedder for the input language
    encoder: An encoder to generate encoded inputs
    transform: A transform to be applied before making predictions
    scorer: The class to actually make predictions
    inference: The inference method used for this model
    auto_cut_pad: If ``True``, cut or pad target sequences so they match the length of the
                  encoded inputs. If ``False``, an error is thrown if there is a length mismatch.
  """

  yaml_tag = '!SeqLabeler'

  @events.register_xnmt_handler
  @serializable_init
  def __init__(self,
               src_reader: input_readers.InputReader,
               trg_reader: input_readers.InputReader,
               src_embedder: embedders.Embedder = bare(embedders.SimpleWordEmbedder),
               encoder: transducers.SeqTransducer = bare(recurrent.BiLSTMSeqTransducer),
               transform: transforms.Transform = bare(transforms.NonLinear),
               scorer: scorers.Scorer = bare(scorers.Softmax),
               inference: inferences.Inference = bare(inferences.IndependentOutputInference),
               auto_cut_pad: bool = False) -> None:
    super().__init__(src_reader=src_reader, trg_reader=trg_reader)
    self.src_embedder = src_embedder
    self.encoder = encoder
    self.transform = transform
    self.scorer = scorer
    self.inference = inference
    self.auto_cut_pad = auto_cut_pad
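As a usage sketch (the vocab file names below are hypothetical stand-ins; every other component falls back to the bare() defaults in the signature above), a labeler can be constructed directly in Python:

from xnmt import input_readers, vocabs
from xnmt.models.sequence_labelers import SeqLabeler

# Hypothetical vocab files; PlainTextReader and Vocab are existing xnmt components.
src_vocab = vocabs.Vocab(vocab_file="src.vocab")
trg_vocab = vocabs.Vocab(vocab_file="trg.vocab")
model = SeqLabeler(src_reader=input_readers.PlainTextReader(vocab=src_vocab),
                   trg_reader=input_readers.PlainTextReader(vocab=trg_vocab),
                   auto_cut_pad=True)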
  def shared_params(self) -> Sequence[Set[str]]:
    return [{".src_embedder.emb_dim", ".encoder.input_dim"}]
  def _encode_src(self, src: Union[sent.Sentence, batchers.Batch]) -> tuple:
    event_trigger.start_sent(src)
    embeddings = self.src_embedder.embed_sent(src)
    encodings = self.encoder.transduce(embeddings)
    encodings_tensor = encodings.as_tensor()
    ((hidden_dim, seq_len), batch_size) = encodings.dim()
    # Fold the time dimension into the batch dimension so the transform and scorer
    # treat every position as an independent batch element.
    encoding_reshaped = dy.reshape(encodings_tensor, (hidden_dim,), batch_size=batch_size * seq_len)
    outputs = self.transform.transform(encoding_reshaped)
    return batch_size, encodings, outputs, seq_len
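The shape bookkeeping behind that reshape can be checked in isolation. A minimal standalone sketch with toy dimensions, using plain dynet and numpy only:

import dynet as dy
import numpy as np

dy.renew_cg()
hidden_dim, seq_len, batch_size = 4, 3, 2
# dynet convention: the trailing numpy axis becomes the batch dimension.
enc = dy.inputTensor(np.random.rand(hidden_dim, seq_len, batch_size), batched=True)
assert enc.dim() == ((hidden_dim, seq_len), batch_size)
folded = dy.reshape(enc, (hidden_dim,), batch_size=batch_size * seq_len)
assert folded.dim() == ((hidden_dim,), batch_size * seq_len)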
  def calc_nll(self,
               src: Union[batchers.Batch, sent.Sentence],
               trg: Union[batchers.Batch, sent.Sentence]) -> dy.Expression:
    assert batchers.is_batched(src) and batchers.is_batched(trg)
    batch_size, encodings, outputs, seq_len = self._encode_src(src)
    if trg.sent_len() != seq_len:
      if self.auto_cut_pad:
        trg = self._cut_or_pad_targets(seq_len, trg)
      else:
        raise ValueError(f"src/trg lengths do not match: {seq_len} != {trg.sent_len()}")
    ref_action = np.asarray([trg_sent.words for trg_sent in trg]).reshape((seq_len * batch_size,))
    loss_expr_perstep = self.scorer.calc_loss(outputs, batchers.mark_as_batch(ref_action))
    loss_expr_perstep = dy.reshape(loss_expr_perstep, (seq_len,), batch_size=batch_size)
    if trg.mask:
      # Mask convention: 1 marks padded positions, so valid positions are kept via (1 - mask).
      loss_expr_perstep = dy.cmult(loss_expr_perstep, dy.inputTensor(1.0 - trg.mask.np_arr.T, batched=True))
    loss_expr = dy.sum_elems(loss_expr_perstep)
    return loss_expr
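The masking step can be illustrated in isolation. A minimal sketch with toy data, where dy.pickneglogsoftmax_batch stands in for the Softmax scorer's calc_loss, following the convention above that 1 marks padded positions:

import dynet as dy
import numpy as np

dy.renew_cg()
seq_len, batch_size, vocab_size = 3, 2, 5
logits = dy.inputTensor(np.random.rand(vocab_size, seq_len * batch_size), batched=True)
refs = [0, 1, 2, 3, 4, 0]                 # one reference label per position
per_step = dy.pickneglogsoftmax_batch(logits, refs)
per_step = dy.reshape(per_step, (seq_len,), batch_size=batch_size)
mask = np.array([[0, 0, 1],               # batch x seq_len; 1 = padded
                 [0, 0, 0]])
per_step = dy.cmult(per_step, dy.inputTensor(1.0 - mask.T, batched=True))
loss = dy.sum_elems(per_step)             # one scalar loss per batch element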
  def _cut_or_pad_targets(self, seq_len: numbers.Integral, trg: batchers.Batch) -> batchers.Batch:
    old_mask = trg.mask
    if len(trg[0]) > seq_len:
      # Targets are longer than the encoded inputs: truncate sentences and mask alike.
      trunc_len = len(trg[0]) - seq_len
      trg = batchers.mark_as_batch([trg_sent.get_truncated_sent(trunc_len=trunc_len) for trg_sent in trg])
      if old_mask:
        trg.mask = batchers.Mask(np_arr=old_mask.np_arr[:, :-trunc_len])
    else:
      # Targets are shorter: pad sentences and extend the mask with 1s (= padded positions).
      pad_len = seq_len - len(trg[0])
      trg = batchers.mark_as_batch([trg_sent.create_padded_sent(pad_len=pad_len) for trg_sent in trg])
      if old_mask:
        trg.mask = batchers.Mask(
          np_arr=np.pad(old_mask.np_arr, pad_width=((0, 0), (0, pad_len)), mode="constant", constant_values=1))
    return trg
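The mask arithmetic can be verified with plain numpy (toy mask; 1 marks padded positions, matching the convention above):

import numpy as np

old = np.array([[0, 0, 1],
                [0, 0, 0]])  # batch x seq_len
# Padding by 2 extends each row with 1s, i.e. the new positions count as padded.
padded = np.pad(old, pad_width=((0, 0), (0, 2)), mode="constant", constant_values=1)
assert padded.shape == (2, 5) and padded[:, -2:].all()
# Truncating by 1 simply slices the trailing positions off.
truncated = old[:, :-1]
assert truncated.shape == (2, 2)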
  def generate(self, src: batchers.Batch, normalize_scores: bool = False) -> Sequence[sent.ReadableSentence]:
    if not batchers.is_batched(src):
      src = batchers.mark_as_batch([src])
    assert src.batch_size() == 1, "batch size > 1 not properly tested"
    batch_size, encodings, outputs, seq_len = self._encode_src(src)
    # Greedy per-position prediction: pick the single best label at every time step.
    best_words, best_scores = self.scorer.best_k(outputs, k=1, normalize_scores=normalize_scores)
    best_words = best_words[0, :]
    score = np.sum(best_scores, axis=1)
    outputs = [sent.SimpleSentence(words=best_words,
                                   idx=src[0].idx,
                                   vocab=self.trg_vocab if hasattr(self, "trg_vocab") else None,
                                   output_procs=self.trg_reader.output_procs,
                                   score=score)]
    return outputs
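A hypothetical inference sketch, assuming `model` was built as in the constructor example above and `src_sent` is a single already-read source sentence:

import dynet as dy
from xnmt import batchers

dy.renew_cg()
model.set_trg_vocab(trg_vocab)   # optional: map predicted IDs back to label strings
outputs = model.generate(batchers.mark_as_batch([src_sent]), normalize_scores=True)
print(outputs[0].sent_str())     # ReadableSentence provides a plain-text rendering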
  def set_trg_vocab(self, trg_vocab: Optional[vocabs.Vocab] = None) -> None:
    """
    Set target vocab for generating outputs. If not specified, word IDs are generated instead.

    Args:
      trg_vocab: target vocab, or None to generate word IDs
    """
    self.trg_vocab = trg_vocab