Source code for xnmt.transducers.pyramidal

from typing import Any, List, Sequence, Union
import numbers

import dynet as dy

from xnmt.transducers import recurrent
from xnmt import expression_seqs, events
from xnmt.persistence import serializable_init, Serializable, Ref
from xnmt.transducers import base as transducers


class PyramidalLSTMSeqTransducer(transducers.SeqTransducer, Serializable):
  """
  Builder for pyramidal RNNs that delegates to ``UniLSTMSeqTransducer`` objects and wires them together.

  See https://arxiv.org/abs/1508.01211

  Every layer (except the first) reduces sequence length by the specified factor.

  Args:
    layers: number of layers
    input_dim: input dimension
    hidden_dim: hidden dimension
    downsampling_method: how to perform downsampling (concat|skip)
    reduce_factor: integer, or list of ints (different skip for each layer)
    dropout: dropout probability; if None, use exp_global.dropout
    builder_layers: set automatically
  """
  yaml_tag = '!PyramidalLSTMSeqTransducer'

  @events.register_xnmt_handler
  @serializable_init
  def __init__(self,
               layers: numbers.Integral = 1,
               input_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
               hidden_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
               downsampling_method: str = "concat",
               reduce_factor: Union[numbers.Integral, Sequence[numbers.Integral]] = 2,
               dropout: float = Ref("exp_global.dropout", default=0.0),
               builder_layers: Any = None):
    self.dropout = dropout
    assert layers > 0
    assert hidden_dim % 2 == 0
    assert type(reduce_factor) == int or (type(reduce_factor) == list and len(reduce_factor) == layers - 1)
    assert downsampling_method in ["concat", "skip"]
    self.downsampling_method = downsampling_method
    self.reduce_factor = reduce_factor
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.builder_layers = self.add_serializable_component(
      "builder_layers", builder_layers,
      lambda: self.make_builder_layers(input_dim, hidden_dim, layers, dropout, downsampling_method, reduce_factor))

  def make_builder_layers(self, input_dim, hidden_dim, layers, dropout, downsampling_method, reduce_factor):
    builder_layers = []
    f = recurrent.UniLSTMSeqTransducer(input_dim=input_dim, hidden_dim=hidden_dim // 2, dropout=dropout)
    b = recurrent.UniLSTMSeqTransducer(input_dim=input_dim, hidden_dim=hidden_dim // 2, dropout=dropout)
    builder_layers.append([f, b])
    for _ in range(layers - 1):
      layer_input_dim = hidden_dim if downsampling_method == "skip" else hidden_dim * reduce_factor
      f = recurrent.UniLSTMSeqTransducer(input_dim=layer_input_dim, hidden_dim=hidden_dim // 2, dropout=dropout)
      b = recurrent.UniLSTMSeqTransducer(input_dim=layer_input_dim, hidden_dim=hidden_dim // 2, dropout=dropout)
      builder_layers.append([f, b])
    return builder_layers

  @events.handle_xnmt_event
  def on_start_sent(self, src):
    self._final_states = None

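  # Illustration (not part of the original module): with layers=3 and reduce_factor=2,
  # every layer except the last downsamples its output before feeding the next layer,
  # so an input of length 8 is reduced to length 4 after layer 1 and length 2 after
  # layer 2; the final layer emits 2 output states. With downsampling_method="concat",
  # the reduce_factor consecutive steps are concatenated, so subsequent layers see
  # inputs of size hidden_dim * reduce_factor, which is why make_builder_layers uses
  # layer_input_dim = hidden_dim * reduce_factor in that case; with "skip", dropped
  # steps are discarded and the layer input stays at hidden_dim.
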
  def get_final_states(self) -> List[transducers.FinalTransducerState]:
    return self._final_states

  def _reduce_factor_for_layer(self, layer_i):
    if layer_i >= len(self.builder_layers) - 1:
      return 1
    elif type(self.reduce_factor) == int:
      return self.reduce_factor
    else:
      return self.reduce_factor[layer_i]

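  # Illustration (not part of the original module): with layers=3 and
  # reduce_factor=[2, 3], this returns 2 for layer 0, 3 for layer 1, and 1 for the
  # final layer, since the last layer never downsamples its output.
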
  def transduce(self, es: expression_seqs.ExpressionSequence) -> expression_seqs.ExpressionSequence:
    """
    returns the list of output Expressions obtained by adding the given inputs
    to the current state, one by one, to both the forward and backward RNNs,
    and concatenating.

    Args:
      es: an ExpressionSequence
    """
    es_list = [es]

    for layer_i, (fb, bb) in enumerate(self.builder_layers):
      reduce_factor = self._reduce_factor_for_layer(layer_i)

      if es_list[0].mask is None:
        mask_out = None
      else:
        mask_out = es_list[0].mask.lin_subsampled(reduce_factor)

      if self.downsampling_method == "concat" and len(es_list[0]) % reduce_factor != 0:
        raise ValueError(f"For 'concat' subsampling, sequence lengths must be multiples of the total reduce factor, "
                         f"but got sequence length={len(es_list[0])} for reduce_factor={reduce_factor}. "
                         f"Set Batcher's pad_src_to_multiple argument accordingly.")
      fs = fb.transduce(es_list)
      bs = bb.transduce([expression_seqs.ReversedExpressionSequence(es_item) for es_item in es_list])
      if layer_i < len(self.builder_layers) - 1:
        if self.downsampling_method == "skip":
          es_list = [expression_seqs.ExpressionSequence(expr_list=fs[::reduce_factor], mask=mask_out),
                     expression_seqs.ExpressionSequence(expr_list=bs[::reduce_factor][::-1], mask=mask_out)]
        elif self.downsampling_method == "concat":
          es_len = len(es_list[0])
          es_list_fwd = []
          es_list_bwd = []
          for i in range(0, es_len, reduce_factor):
            for j in range(reduce_factor):
              if i == 0:
                es_list_fwd.append([])
                es_list_bwd.append([])
              es_list_fwd[j].append(fs[i + j])
              es_list_bwd[j].append(bs[len(es_list[0]) - reduce_factor + j - i])
          es_list = [expression_seqs.ExpressionSequence(expr_list=es_list_fwd[j], mask=mask_out) for j in range(reduce_factor)] + \
                    [expression_seqs.ExpressionSequence(expr_list=es_list_bwd[j], mask=mask_out) for j in range(reduce_factor)]
        else:
          raise RuntimeError(f"unknown downsampling_method {self.downsampling_method}")
      else:
        # concat final outputs
        ret_es = expression_seqs.ExpressionSequence(
          expr_list=[dy.concatenate([f, b]) for f, b in zip(fs, expression_seqs.ReversedExpressionSequence(bs))],
          mask=mask_out)

    self._final_states = [
      transducers.FinalTransducerState(dy.concatenate([fb.get_final_states()[0].main_expr(),
                                                       bb.get_final_states()[0].main_expr()]),
                                       dy.concatenate([fb.get_final_states()[0].cell_expr(),
                                                       bb.get_final_states()[0].cell_expr()]))
      for (fb, bb) in self.builder_layers]

    return ret_es
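
# Usage sketch (an illustration, not part of the original module). Assuming a fully
# initialized xnmt/DyNet setup and an already-embedded source ExpressionSequence
# (here called embedded_src, a hypothetical name), the transducer could be used
# roughly as follows:
#
#   pyramidal = PyramidalLSTMSeqTransducer(layers=3,
#                                          input_dim=512,
#                                          hidden_dim=512,
#                                          downsampling_method="concat",
#                                          reduce_factor=2)
#   outputs = pyramidal.transduce(embedded_src)
#   final_states = pyramidal.get_final_states()
#
# With "concat" downsampling, the input length must be a multiple of
# reduce_factor ** (layers - 1) (a multiple of 4 in this example), which can be
# ensured via the Batcher's pad_src_to_multiple argument mentioned in the error
# message above.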