Source code for xnmt.modelparts.bridges

from typing import List, Optional, Sequence
import numbers

import dynet as dy

from xnmt import param_initializers
from xnmt.modelparts import transforms
from xnmt.persistence import serializable_init, Serializable, Ref, bare
from xnmt.transducers import base as transducers

class Bridge(object):
  """
  Responsible for initializing the decoder LSTM, based on the final encoder state
  """

  def decoder_init(self, enc_final_states: Sequence[transducers.FinalTransducerState]) -> List[dy.Expression]:
    """
    Args:
      enc_final_states: list of final states for each encoder layer

    Returns:
      list of initial cell and hidden expressions for each decoder layer. List indices 0..n-1 hold
      cell states, indices n..2n-1 hold hidden states (this is the ordering produced by the concrete
      bridges below).
    """
    raise NotImplementedError("decoder_init() must be implemented by Bridge subclasses")

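As a purely illustrative helper (not part of xnmt), a caller could unpack the 2n-entry list returned by decoder_init() like this; the cell-first ordering follows the concrete bridges below:

def split_decoder_init(init_states: List[dy.Expression], dec_layers: numbers.Integral):
  # Illustrative only: the first dec_layers entries are cell states, the rest are hidden states.
  cell_states = init_states[:dec_layers]
  hidden_states = init_states[dec_layers:]
  return cell_states, hidden_states
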
class NoBridge(Bridge, Serializable):
  """
  This bridge initializes the decoder with zero vectors, disregarding the encoder final states.

  Args:
    dec_layers: number of decoder layers to initialize
    dec_dim: hidden dimension of decoder states
  """
  yaml_tag = '!NoBridge'

  @serializable_init
  def __init__(self,
               dec_layers: numbers.Integral = 1,
               dec_dim: numbers.Integral = Ref("exp_global.default_layer_dim")) -> None:
    self.dec_layers = dec_layers
    self.dec_dim = dec_dim

  def decoder_init(self, enc_final_states: Sequence[transducers.FinalTransducerState]) -> List[dy.Expression]:
    batch_size = enc_final_states[0].main_expr().dim()[1]
    z = dy.zeros(self.dec_dim, batch_size)
    return [z] * (self.dec_layers * 2)

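A rough usage sketch for NoBridge, assuming xnmt and its DyNet backend are installed and that components can be instantiated directly (as in xnmt's unit tests); the dimensions are made up. NoBridge only reads the batch size from the encoder states, so a single dummy FinalTransducerState suffices:

import dynet as dy
from xnmt.modelparts.bridges import NoBridge
from xnmt.transducers.base import FinalTransducerState

dy.renew_cg()
enc_final = FinalTransducerState(dy.zeros(512))  # cell state defaults to the main expression
bridge = NoBridge(dec_layers=2, dec_dim=512)
init_states = bridge.decoder_init([enc_final])
assert len(init_states) == 4                     # 2 layers * (cell + hidden), all zero vectors
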
class CopyBridge(Bridge, Serializable):
  """
  This bridge copies final states from the encoder to the decoder initial states.

  Requires that:
  - encoder / decoder dimensions match for every layer
  - num encoder layers >= num decoder layers (if unequal, we disregard final states at the encoder bottom)

  Args:
    dec_layers: number of decoder layers to initialize
    dec_dim: hidden dimension of decoder states
  """
  yaml_tag = '!CopyBridge'

  @serializable_init
  def __init__(self,
               dec_layers: numbers.Integral = 1,
               dec_dim: numbers.Integral = Ref("exp_global.default_layer_dim")) -> None:
    self.dec_layers = dec_layers
    self.dec_dim = dec_dim

  def decoder_init(self, enc_final_states: Sequence[transducers.FinalTransducerState]) -> List[dy.Expression]:
    if self.dec_layers > len(enc_final_states):
      raise RuntimeError("CopyBridge requires dec_layers <= len(enc_final_states), but got %s and %s"
                         % (self.dec_layers, len(enc_final_states)))
    if enc_final_states[0].main_expr().dim()[0][0] != self.dec_dim:
      raise RuntimeError("CopyBridge requires enc_dim == dec_dim, but got %s and %s"
                         % (enc_final_states[0].main_expr().dim()[0][0], self.dec_dim))
    return [enc_state.cell_expr() for enc_state in enc_final_states[-self.dec_layers:]] \
         + [enc_state.main_expr() for enc_state in enc_final_states[-self.dec_layers:]]

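A rough usage sketch for CopyBridge under the same assumptions (direct instantiation, made-up dimensions). With two encoder layers and one decoder layer, only the topmost encoder state is copied and the bottom one is disregarded:

import dynet as dy
from xnmt.modelparts.bridges import CopyBridge
from xnmt.transducers.base import FinalTransducerState

dy.renew_cg()
dim = 512                                          # encoder and decoder dimensions must match
enc_bottom = FinalTransducerState(dy.zeros(dim))   # bottom encoder layer, disregarded here
enc_top = FinalTransducerState(dy.zeros(dim))      # topmost encoder layer, copied to the decoder
bridge = CopyBridge(dec_layers=1, dec_dim=dim)
init_states = bridge.decoder_init([enc_bottom, enc_top])
assert len(init_states) == 2                       # cell state followed by hidden state
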
class LinearBridge(Bridge, Serializable):
  """
  This bridge does a linear transform of final states from the encoder to the decoder initial states.

  Requires that num encoder layers >= num decoder layers (if unequal, we disregard final states at the
  encoder bottom)

  Args:
    dec_layers: number of decoder layers to initialize
    enc_dim: hidden dimension of encoder states
    dec_dim: hidden dimension of decoder states
    param_init: how to initialize weight matrices; if None, use ``exp_global.param_init``
    bias_init: how to initialize bias vectors; if None, use ``exp_global.bias_init``
    projector: linear projection (created automatically)
  """
  yaml_tag = '!LinearBridge'

  def decoder_init(self, enc_final_states: Sequence[transducers.FinalTransducerState]) -> List[dy.Expression]:
    if self.dec_layers > len(enc_final_states):
      raise RuntimeError(
        f"LinearBridge requires dec_layers <= len(enc_final_states), but got {self.dec_layers} and {len(enc_final_states)}")
    if enc_final_states[0].main_expr().dim()[0][0] != self.enc_dim:
      raise RuntimeError(
        f"LinearBridge requires enc_dim == {self.enc_dim}, but got {enc_final_states[0].main_expr().dim()[0][0]}")
    decoder_init = [self.projector.transform(enc_state.main_expr())
                    for enc_state in enc_final_states[-self.dec_layers:]]
    return decoder_init + [dy.tanh(dec) for dec in decoder_init]

  @serializable_init
  def __init__(self,
               dec_layers: numbers.Integral = 1,
               enc_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
               dec_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
               param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init",
                                                                     default=bare(param_initializers.GlorotInitializer)),
               bias_init: param_initializers.ParamInitializer = Ref("exp_global.bias_init",
                                                                    default=bare(param_initializers.ZeroInitializer)),
               projector: Optional[transforms.Linear] = None):
    self.dec_layers = dec_layers
    self.enc_dim = enc_dim
    self.dec_dim = dec_dim
    self.projector = self.add_serializable_component("projector", projector,
                                                     lambda: transforms.Linear(input_dim=self.enc_dim,
                                                                               output_dim=self.dec_dim,
                                                                               param_init=param_init,
                                                                               bias_init=bias_init))
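
A rough usage sketch for LinearBridge. Because the projection owns DyNet parameters, xnmt's global parameter collection must be set up first; the sketch assumes this is done via ParamManager.init_param_col() imported from xnmt.param_collections (the pattern used in xnmt's test suite; the exact import path is an assumption here), and it passes explicit initializers instead of relying on the Ref defaults. Dimensions are made up:

import dynet as dy
from xnmt.param_collections import ParamManager
from xnmt import param_initializers
from xnmt.modelparts.bridges import LinearBridge
from xnmt.transducers.base import FinalTransducerState

ParamManager.init_param_col()                      # the Linear projection needs a parameter collection
dy.renew_cg()
bridge = LinearBridge(dec_layers=1, enc_dim=256, dec_dim=512,
                      param_init=param_initializers.GlorotInitializer(),
                      bias_init=param_initializers.ZeroInitializer())
enc_final = FinalTransducerState(dy.zeros(256))
init_states = bridge.decoder_init([enc_final])
assert init_states[0].dim()[0] == (512,)           # encoder state projected from 256 to 512 dimensions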