from typing import Any
import numbers
import dynet as dy
from xnmt import batchers, param_collections
from xnmt.modelparts import bridges, transforms, scorers, embedders
from xnmt.transducers import recurrent
from xnmt.persistence import serializable_init, Serializable, bare, Ref


class Decoder(object):
  """
  A template class to convert a prefix of previously generated words and
  a context vector into a probability distribution over possible next words.
  """
  def initial_state(self, enc_final_states, ss_expr):
    raise NotImplementedError('must be implemented by subclasses')

  def add_input(self, dec_state, trg_embedding):
    raise NotImplementedError('must be implemented by subclasses')

  def calc_loss(self, dec_state, ref_action):
    raise NotImplementedError('must be implemented by subclasses')

  def calc_score(self, dec_state, action, normalize=False):
    raise NotImplementedError('must be implemented by subclasses')

  def best_k(self, dec_state, k, normalize_scores=False):
    raise NotImplementedError('must be implemented by subclasses')

  def sample(self, dec_state, n, temperature=1.0):
    raise NotImplementedError('must be implemented by subclasses')
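

# Hypothetical driver sketch, not part of xnmt: how a greedy search might
# exercise the Decoder interface above. `embed_fn` is an assumed callable
# mapping a word id to its embedding (the template's add_input expects a
# target embedding); in practice one would also stop at the end-of-sentence
# symbol rather than always running max_len steps.
def _example_greedy_decode(decoder, enc_final_states, start_symbol, embed_fn, max_len=100):
  """Sketch: repeatedly pick the single best next word, up to max_len steps."""
  dec_state = decoder.initial_state(enc_final_states, start_symbol)
  output = []
  for _ in range(max_len):
    best_words, best_scores = decoder.best_k(dec_state, k=1)
    output.append(best_words[0])
    # Feed the embedding of the word just generated back into the decoder.
    dec_state = decoder.add_input(dec_state, embed_fn(best_words[0]))
  return output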


class DecoderState(object):
  """A state that holds whatever information is required for the decoder.
  Child classes must implement the as_vector() method, which will be used,
  e.g., by the attention mechanism."""

  def as_vector(self):
    raise NotImplementedError('must be implemented by subclass')


class AutoRegressiveDecoderState(DecoderState):
  """A state holding all the information needed for AutoRegressiveDecoder.

  Args:
    rnn_state: a DyNet RNN state
    context: a DyNet expression
  """
  def __init__(self, rnn_state=None, context=None):
    self._rnn_state = rnn_state
    self._context = context

  def as_vector(self):
    return self.rnn_state.output()

  @property
  def rnn_state(self):
    return self._rnn_state

  @property
  def context(self):
    return self._context

  @context.setter
  def context(self, value):
    self._context = value
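
# Illustrative note, not part of xnmt: as_vector() provides the query that an
# attention mechanism consumes. A typical step might look as follows, where
# `attender` and its calc_context method are assumed for illustration:
#
#   query = dec_state.as_vector()                     # decoder RNN output
#   dec_state.context = attender.calc_context(query)  # attention-weighted encoder states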


class AutoRegressiveDecoder(Decoder, Serializable):
  """
  Standard autoregressive decoder.

  Args:
    input_dim: input dimension
    embedder: embedder for target words
    input_feeding: whether to activate input feeding
    bridge: how to initialize the decoder state
    rnn: recurrent decoder
    transform: a transformation layer between the RNN and the output scorer
    scorer: the method of scoring the output (usually softmax)
    truncate_dec_batches: whether the decoder drops batch elements as soon as they are masked at some time step
  """
  yaml_tag = '!AutoRegressiveDecoder'

  @serializable_init
  def __init__(self,
               input_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
               embedder: embedders.Embedder = bare(embedders.SimpleWordEmbedder),
               input_feeding: bool = True,
               bridge: bridges.Bridge = bare(bridges.CopyBridge),
               rnn: recurrent.UniLSTMSeqTransducer = bare(recurrent.UniLSTMSeqTransducer),
               transform: transforms.Transform = bare(transforms.AuxNonLinear),
               scorer: scorers.Scorer = bare(scorers.Softmax),
               truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
    self.param_col = param_collections.ParamManager.my_params(self)
    self.input_dim = input_dim
    self.embedder = embedder
    self.truncate_dec_batches = truncate_dec_batches
    self.bridge = bridge
    self.rnn = rnn
    self.transform = transform
    self.scorer = scorer
    # Input feeding appends the previous attention context to the word
    # embedding, so the RNN input dimension grows by input_dim.
    self.input_feeding = input_feeding
    rnn_input_dim = embedder.emb_dim
    if input_feeding:
      rnn_input_dim += input_dim
    assert rnn_input_dim == rnn.total_input_dim, "Wrong input dimension in RNN layer: {} != {}".format(rnn_input_dim, rnn.total_input_dim)
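
  # Configuration sketch: since the class carries the YAML tag
  # '!AutoRegressiveDecoder', it is usually instantiated from an xnmt
  # experiment spec rather than constructed directly. A hypothetical
  # fragment (values are illustrative, not defaults):
  #
  #   decoder: !AutoRegressiveDecoder
  #     input_feeding: True
  #     bridge: !CopyBridge {}
  #     rnn: !UniLSTMSeqTransducer
  #       layers: 1
  #     scorer: !Softmax {}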

  def shared_params(self):
    return [{".embedder.emb_dim", ".rnn.input_dim"},
            {".input_dim", ".rnn.decoder_input_dim"},
            {".input_dim", ".transform.input_dim"},
            {".input_feeding", ".rnn.decoder_input_feeding"},
            {".rnn.layers", ".bridge.dec_layers"},
            {".rnn.hidden_dim", ".bridge.dec_dim"},
            {".rnn.hidden_dim", ".transform.aux_input_dim"},
            {".transform.output_dim", ".scorer.input_dim"}]

  def initial_state(self, enc_final_states: Any, ss: Any) -> AutoRegressiveDecoderState:
    """Get the initial state of the decoder given the encoder final states.

    Args:
      enc_final_states: The encoder final states. Usually but not necessarily an :class:`xnmt.expression_sequence.ExpressionSequence`
      ss: first input
    Returns:
      initial decoder state
    """
    rnn_state = self.rnn.initial_state()
    rnn_s = self.bridge.decoder_init(enc_final_states)
    rnn_state = rnn_state.set_s(rnn_s)
    zeros = dy.zeros(self.input_dim) if self.input_feeding else None
    ss_expr = self.embedder.embed(ss)
    rnn_state = rnn_state.add_input(dy.concatenate([ss_expr, zeros]) if self.input_feeding else ss_expr)
    return AutoRegressiveDecoderState(rnn_state=rnn_state, context=zeros)
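
  # Usage sketch, with assumed surrounding objects: decoding is typically
  # seeded with the encoder's final states and the start-of-sentence symbol,
  # along these lines:
  #
  #   dec_state = decoder.initial_state(encoder.get_final_states(),
  #                                     batchers.mark_as_batch([Vocab.SS]))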

  def _calc_transform(self, mlp_dec_state: AutoRegressiveDecoderState) -> dy.Expression:
    # Concatenate the RNN output with the attention context before applying
    # the output transform.
    h = dy.concatenate([mlp_dec_state.rnn_state.output(), mlp_dec_state.context])
    return self.transform.transform(h)

  def best_k(self, mlp_dec_state: AutoRegressiveDecoderState, k: numbers.Integral, normalize_scores: bool = False):
    h = self._calc_transform(mlp_dec_state)
    best_words, best_scores = self.scorer.best_k(h, k, normalize_scores=normalize_scores)
    return best_words, best_scores

  def sample(self, mlp_dec_state: AutoRegressiveDecoderState, n: numbers.Integral, temperature=1.0):
    h = self._calc_transform(mlp_dec_state)
    return self.scorer.sample(h, n, temperature)

  def calc_loss(self, mlp_dec_state, ref_action):
    return self.scorer.calc_loss(self._calc_transform(mlp_dec_state), ref_action)
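

# Training-loop sketch, not part of xnmt: summing per-step losses over a
# reference sequence with teacher forcing. `ref_words` is assumed to begin
# with the start-of-sentence symbol; add_input is fed the embedding of each
# reference word, matching the trg_embedding argument declared on the
# Decoder template above.
def _example_teacher_forcing_loss(decoder, enc_final_states, ref_words):
  """Sketch: sum of per-step losses for ref_words under teacher forcing."""
  dec_state = decoder.initial_state(enc_final_states, ref_words[0])
  losses = []
  for t in range(1, len(ref_words)):
    losses.append(decoder.calc_loss(dec_state, ref_words[t]))  # score next reference word
    dec_state = decoder.add_input(dec_state, decoder.embedder.embed(ref_words[t]))
  return dy.esum(losses)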