# Source code for xnmt.modelparts.attenders

import math
import numbers

import numpy as np
import dynet as dy

from xnmt import logger
from xnmt import batchers, expression_seqs, events, param_collections, param_initializers
from xnmt.persistence import serializable_init, Serializable, Ref, bare

class Attender(object):
    """
    A template class for functions implementing attention.

    Subclasses must implement :meth:`init_sent` and :meth:`calc_attention`. They are expected to
    keep the encoder states passed to :meth:`init_sent` in ``self.curr_sent``, which the default
    :meth:`calc_context` reads.
    """

    def init_sent(self, sent: expression_seqs.ExpressionSequence) -> None:
        """
        Prepare the attender for a new source sentence.

        Args:
          sent: the encoder states, aka keys and values. Usually but not necessarily an
                :class:`expression_seqs.ExpressionSequence`

        Raises:
          NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError('init_sent must be implemented for Attender subclasses')

    def calc_attention(self, state: dy.Expression) -> dy.Expression:
        """
        Compute attention weights.

        Args:
          state: the current decoder state, aka query, for which to compute the weights.

        Returns:
          DyNet expression containing normalized attention scores

        Raises:
          NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError('calc_attention must be implemented for Attender subclasses')

    def calc_context(self, state: dy.Expression, attention: dy.Expression = None) -> dy.Expression:
        """
        Compute weighted sum.

        Args:
          state: the current decoder state, aka query, for which to compute the weighted sum.
          attention: the attention vector to use. If not given (``None``) it is calculated from
                     the state via :meth:`calc_attention`.

        Returns:
          ``I * attention``, where ``I`` is ``self.curr_sent.as_tensor()``.
        """
        # Explicit None check instead of ``attention or ...``: the truth value of an
        # expression/tensor object is not a reliable "was it supplied?" signal.
        if attention is None:
            attention = self.calc_attention(state)
        I = self.curr_sent.as_tensor()
        return I * attention
class MlpAttender(Attender, Serializable):
    """
    Implements the attention model of Bahdanau et al. (2014).

    For encoder states ``I`` and decoder state ``s``, the unnormalized scores are
    ``U * tanh(W * I + b + V * s)`` (one score per source position), followed by a softmax.

    Args:
      input_dim: input dimension
      state_dim: dimension of state inputs
      hidden_dim: hidden MLP dimension
      param_init: how to initialize weight matrices
      bias_init: how to initialize bias vectors
      truncate_dec_batches: whether the decoder drops batch elements as soon as these are masked
                            at some time step.
    """
    yaml_tag = '!MlpAttender'

    @serializable_init
    def __init__(self,
                 input_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
                 state_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
                 hidden_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
                 param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
                 bias_init: param_initializers.ParamInitializer = Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer)),
                 truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
        self.input_dim = input_dim
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim
        self.truncate_dec_batches = truncate_dec_batches
        param_collection = param_collections.ParamManager.my_params(self)
        # W projects encoder states, V projects the decoder state, U reduces to a scalar score.
        self.pW = param_collection.add_parameters((hidden_dim, input_dim), init=param_init.initializer((hidden_dim, input_dim)))
        self.pV = param_collection.add_parameters((hidden_dim, state_dim), init=param_init.initializer((hidden_dim, state_dim)))
        self.pb = param_collection.add_parameters((hidden_dim,), init=bias_init.initializer((hidden_dim,)))
        self.pU = param_collection.add_parameters((1, hidden_dim), init=param_init.initializer((1, hidden_dim)))
        self.curr_sent = None
        self.attention_vecs = None
        self.WI = None

    def init_sent(self, sent: expression_seqs.ExpressionSequence) -> None:
        """
        Cache the encoder states and precompute ``W * I + b``, which is shared by all decoder steps.

        Args:
          sent: the encoder states (keys and values).
        """
        self.attention_vecs = []
        self.curr_sent = sent
        I = self.curr_sent.as_tensor()
        W = dy.parameter(self.pW)
        b = dy.parameter(self.pb)
        self.WI = dy.affine_transform([b, W, I])
        wi_dim = self.WI.dim()
        # TODO(philip30): dynet affine transform bug, should be fixed upstream
        # if the input size is "1" then the last dimension will be dropped.
        if len(wi_dim[0]) == 1:
            self.WI = dy.reshape(self.WI, (wi_dim[0][0], 1), batch_size=wi_dim[1])

    def calc_attention(self, state: dy.Expression) -> dy.Expression:
        """
        Compute normalized MLP attention weights for one decoder state.

        Args:
          state: the current decoder state (query).

        Returns:
          Softmax-normalized attention scores over source positions; also appended to
          ``self.attention_vecs``.
        """
        V = dy.parameter(self.pV)
        U = dy.parameter(self.pU)
        WI = self.WI
        curr_sent_mask = self.curr_sent.mask
        if self.truncate_dec_batches:
            # Explicit None check, consistent with the masking check below.
            if curr_sent_mask is not None:
                state, WI, curr_sent_mask = batchers.truncate_batches(state, WI, curr_sent_mask)
            else:
                state, WI = batchers.truncate_batches(state, WI)
        h = dy.tanh(dy.colwise_add(WI, V * state))
        # U * h is a (1 x len) row; transpose to a column so softmax normalizes over positions.
        scores = dy.transpose(U * h)
        if curr_sent_mask is not None:
            scores = curr_sent_mask.add_to_tensor_expr(scores, multiplicator=-100.0)
        normalized = dy.softmax(scores)
        self.attention_vecs.append(normalized)
        return normalized
class DotAttender(Attender, Serializable):
    """
    Dot-product attention in the style of Luong et al. (2015), https://arxiv.org/abs/1508.04025,
    with optional scaling by ``1/sqrt(d)`` as in Vaswani et al. (2017),
    https://arxiv.org/abs/1706.03762, where ``d`` is the size of the query's first dimension.

    Args:
      scale: whether to divide the raw scores by the square root of the query dimension
      truncate_dec_batches: currently unsupported; a true value raises NotImplementedError
    """
    yaml_tag = '!DotAttender'

    @serializable_init
    def __init__(self,
                 scale: bool = True,
                 truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
        if truncate_dec_batches:
            raise NotImplementedError("truncate_dec_batches not yet implemented for DotAttender")
        self.curr_sent = None
        self.scale = scale
        self.attention_vecs = []

    def init_sent(self, sent: expression_seqs.ExpressionSequence) -> None:
        """
        Store the encoder states and keep their transpose for the per-step dot products.

        Args:
          sent: the encoder states (keys and values).
        """
        self.curr_sent = sent
        self.attention_vecs = []
        keys = self.curr_sent.as_tensor()
        self.I = dy.transpose(keys)

    def calc_attention(self, state: dy.Expression) -> dy.Expression:
        """
        Score every source position by its dot product with ``state`` and normalize.

        Args:
          state: the current decoder state (query).

        Returns:
          Softmax-normalized attention weights; also recorded in ``self.attention_vecs``.
        """
        logits = self.I * state
        if self.scale:
            query_dim = state.dim()[0][0]
            logits /= math.sqrt(query_dim)
        sent_mask = self.curr_sent.mask
        if sent_mask is not None:
            # Push padded positions toward zero weight.
            logits = sent_mask.add_to_tensor_expr(logits, multiplicator=-100.0)
        weights = dy.softmax(logits)
        self.attention_vecs.append(weights)
        return weights
class BilinearAttender(Attender, Serializable):
    """
    Implements a bilinear attention, equivalent to the 'general' attention of Luong et al. (2015),
    https://arxiv.org/abs/1508.04025: ``score(s, h_j) = s^T * Wa * h_j``.

    Args:
      input_dim: input dimension; if None, use exp_global.default_layer_dim
      state_dim: dimension of state inputs; if None, use exp_global.default_layer_dim
      param_init: how to initialize weight matrices; if None, use ``exp_global.param_init``
      truncate_dec_batches: currently unsupported
    """
    yaml_tag = '!BilinearAttender'

    @serializable_init
    def __init__(self,
                 input_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
                 state_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
                 param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
                 truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
        if truncate_dec_batches:
            raise NotImplementedError("truncate_dec_batches not yet implemented for BilinearAttender")
        self.input_dim = input_dim
        self.state_dim = state_dim
        param_collection = param_collections.ParamManager.my_params(self)
        # The score is s^T * Wa * I, so Wa must be (state_dim x input_dim). The former
        # (input_dim x state_dim) shape only worked when both dims were equal; in that case the
        # square shape is unchanged, so existing models load and score identically.
        self.pWa = param_collection.add_parameters((state_dim, input_dim), init=param_init.initializer((state_dim, input_dim)))
        self.curr_sent = None
        self.attention_vecs = []
        self.I = None

    def init_sent(self, sent: expression_seqs.ExpressionSequence) -> None:
        """
        Cache the encoder states for the upcoming decoder steps.

        Args:
          sent: the encoder states (keys and values).
        """
        self.curr_sent = sent
        self.attention_vecs = []
        self.I = self.curr_sent.as_tensor()

    def calc_attention(self, state: dy.Expression) -> dy.Expression:
        """
        Compute normalized bilinear attention weights for one decoder state.

        Args:
          state: the current decoder state (query).

        Returns:
          Softmax-normalized attention weights, one per source position, as a column; also
          appended to ``self.attention_vecs``.
        """
        Wa = dy.parameter(self.pWa)
        # (1 x state_dim) * (state_dim x input_dim) * (input_dim x len) -> (1 x len). Transpose to a
        # column: dy.softmax normalizes each column, so a row vector would not be normalized
        # across source positions.
        scores = dy.transpose((dy.transpose(state) * Wa) * self.I)
        if self.curr_sent.mask is not None:
            scores = self.curr_sent.mask.add_to_tensor_expr(scores, multiplicator=-100.0)
        normalized = dy.softmax(scores)
        self.attention_vecs.append(normalized)
        return normalized
class LatticeBiasedMlpAttender(MlpAttender, Serializable):
    """
    Modified MLP attention, where lattices are assumed as input and the attention is biased toward
    confident nodes.

    Each node's ``marginal_log_prob`` is added to its MLP attention score before the softmax.
    Positions that are not lattice nodes receive a bias of -1e10.

    Args:
      input_dim: input dimension
      state_dim: dimension of state inputs
      hidden_dim: hidden MLP dimension
      param_init: how to initialize weight matrices
      bias_init: how to initialize bias vectors
      truncate_dec_batches: whether the decoder drops batch elements as soon as these are masked
                            at some time step.
    """
    yaml_tag = '!LatticeBiasedMlpAttender'

    @events.register_xnmt_handler
    @serializable_init
    def __init__(self,
                 input_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
                 state_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
                 hidden_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
                 param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
                 bias_init: param_initializers.ParamInitializer = Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer)),
                 truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
        super().__init__(input_dim=input_dim, state_dim=state_dim, hidden_dim=hidden_dim,
                         param_init=param_init, bias_init=bias_init,
                         truncate_dec_batches=truncate_dec_batches)

    @events.handle_xnmt_event
    def on_start_sent(self, src):
        """
        Build the per-position bias array for a new source batch.

        The array has shape ``(src.sent_len(), 1, src.batch_size())``. It is -1e10 everywhere
        except at each lattice node's index, which holds that node's ``marginal_log_prob``.
        NOTE(review): assumes node ids index source positions directly; confirm against the
        lattice input type.

        Args:
          src: the source batch; iterated as one lattice per batch element.
        """
        self.cur_sent_bias = np.full((src.sent_len(), 1, src.batch_size()), -1e10)
        for batch_i, lattice_batch_elem in enumerate(src):
            for node_id in lattice_batch_elem.nodes:
                self.cur_sent_bias[node_id, 0, batch_i] = lattice_batch_elem.graph[node_id].marginal_log_prob
        # Built lazily in calc_attention (once per sentence).
        self.cur_sent_bias_expr = None

    def calc_attention(self, state: dy.Expression) -> dy.Expression:
        """
        Compute MLP attention weights biased by the lattice node log-probabilities.

        Args:
          state: the current decoder state (query).

        Returns:
          ``softmax(scores + bias)``; also appended to ``self.attention_vecs``.
        """
        V = dy.parameter(self.pV)
        U = dy.parameter(self.pU)
        WI = self.WI
        curr_sent_mask = self.curr_sent.mask
        if self.cur_sent_bias_expr is None:
            # inputTensor with batched=True treats the last numpy axis as the batch dimension.
            self.cur_sent_bias_expr = dy.inputTensor(self.cur_sent_bias, batched=True)
        bias = self.cur_sent_bias_expr
        if self.truncate_dec_batches:
            # The bias has the same (full) batch size as WI, so it must be truncated together with
            # the other batched inputs. Otherwise ``scores + bias`` would mismatch in batch size.
            if curr_sent_mask is not None:
                state, WI, bias, curr_sent_mask = batchers.truncate_batches(state, WI, bias, curr_sent_mask)
            else:
                state, WI, bias = batchers.truncate_batches(state, WI, bias)
        h = dy.tanh(dy.colwise_add(WI, V * state))
        scores = dy.transpose(U * h)
        if curr_sent_mask is not None:
            scores = curr_sent_mask.add_to_tensor_expr(scores, multiplicator=-1e10)
        normalized = dy.softmax(scores + bias)
        self.attention_vecs.append(normalized)
        return normalized