Source code for xnmt.loss_calculators

from typing import List, Optional, Sequence, Union
import numbers

import dynet as dy
import numpy as np

from xnmt import batchers, event_trigger, losses, search_strategies, sent, vocabs
from xnmt.persistence import bare, Ref, Serializable, serializable_init
from xnmt.modelparts import transforms
from xnmt.eval import metrics


class LossCalculator(object):
  """
  A template class implementing the training strategy and corresponding loss calculation.
  """
  def calc_loss(self,
                model: 'model_base.ConditionedModel',
                src: Union[sent.Sentence, 'batchers.Batch'],
                trg: Union[sent.Sentence, 'batchers.Batch']) -> losses.FactoredLossExpr:
    raise NotImplementedError()

  def remove_eos(self, sequence: Sequence[numbers.Integral], eos_sym=vocabs.Vocab.ES) -> Sequence[numbers.Integral]:
    try:
      idx = sequence.index(eos_sym)
      sequence = sequence[:idx]
    except ValueError:  # no EOS found
      pass
    return sequence
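
# Illustrative subclass sketch (not part of the original module): a concrete
# calculator only needs to map (model, src, trg) to a FactoredLossExpr. The
# class name and the scaling factor here are arbitrary example values.
#
#   class ScaledMLELoss(LossCalculator):
#     def calc_loss(self, model, src, trg):
#       return losses.FactoredLossExpr({"scaled_mle": model.calc_nll(src, trg) * 0.5})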

class MLELoss(Serializable, LossCalculator):
  """
  Max likelihood loss calculator.
  """
  yaml_tag = '!MLELoss'

  @serializable_init
  def __init__(self) -> None:
    pass

  def calc_loss(self,
                model: 'model_base.ConditionedModel',
                src: Union[sent.Sentence, 'batchers.Batch'],
                trg: Union[sent.Sentence, 'batchers.Batch']) -> losses.FactoredLossExpr:
    loss = model.calc_nll(src, trg)
    return losses.FactoredLossExpr({"mle": loss})
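
# Usage sketch (illustrative): inside a training step, assuming `model` is a
# ConditionedModel and `src`/`trg` are batchers.Batch objects built elsewhere.
# The returned FactoredLossExpr bundles named loss components; here it holds a
# single "mle" factor with the batched negative log-likelihood.
#
#   dy.renew_cg()
#   loss_expr = MLELoss().calc_loss(model, src, trg)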

class GlobalFertilityLoss(Serializable, LossCalculator):
  """
  A fertility loss according to Cohn+, 2016.
  Incorporating Structural Alignment Biases into an Attentional Neural Translation Model
  https://arxiv.org/pdf/1601.01085.pdf
  """
  yaml_tag = '!GlobalFertilityLoss'

  @serializable_init
  def __init__(self) -> None:
    pass

  def calc_loss(self,
                model: 'model_base.ConditionedModel',
                src: Union[sent.Sentence, 'batchers.Batch'],
                trg: Union[sent.Sentence, 'batchers.Batch']) -> losses.FactoredLossExpr:
    assert hasattr(model, "attender") and hasattr(model.attender, "attention_vecs"), \
           "Must be called after MLELoss with models that have an attender."
    masked_attn = model.attender.attention_vecs
    if trg.mask is not None:
      trg_mask = 1 - trg.mask.np_arr.transpose()
      masked_attn = [dy.cmult(attn, dy.inputTensor(mask, batched=True))
                     for attn, mask in zip(masked_attn, trg_mask)]
    loss = self.global_fertility(masked_attn)
    return losses.FactoredLossExpr({"global_fertility": loss})

  def global_fertility(self, a: Sequence[dy.Expression]) -> dy.Expression:
    return dy.sum_elems(dy.square(1 - dy.esum(a)))
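
# Worked example (illustrative) of the global_fertility term: with attention
# over a 3-word source at two target steps, the penalty is the squared
# deviation of each source position's total attention mass from 1, and it is
# zero exactly when every source word receives a total mass of 1.
#
#   attns = [dy.inputVector([0.7, 0.2, 0.1]),     # attention at target step 1
#            dy.inputVector([0.2, 0.6, 0.2])]     # attention at target step 2
#   total = dy.esum(attns)                        # per-source-word mass: [0.9, 0.8, 0.3]
#   penalty = dy.sum_elems(dy.square(1 - total))  # 0.01 + 0.04 + 0.49 = 0.54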

class CompositeLoss(Serializable, LossCalculator):
  """
  Sums losses from multiple LossCalculators.
  """
  yaml_tag = "!CompositeLoss"

  @serializable_init
  def __init__(self, pt_losses: List[LossCalculator], loss_weight: Optional[Sequence[numbers.Real]] = None) -> None:
    self.pt_losses = pt_losses
    if loss_weight is None:
      self.loss_weight = [1.0 for _ in range(len(pt_losses))]
    else:
      self.loss_weight = loss_weight
    assert len(self.loss_weight) == len(pt_losses)

  def calc_loss(self,
                model: 'model_base.ConditionedModel',
                src: Union[sent.Sentence, 'batchers.Batch'],
                trg: Union[sent.Sentence, 'batchers.Batch']) -> losses.FactoredLossExpr:
    total_loss = losses.FactoredLossExpr()
    for loss, weight in zip(self.pt_losses, self.loss_weight):
      total_loss.add_factored_loss_expr(loss.calc_loss(model, src, trg) * weight)
    return total_loss
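
# Usage sketch (illustrative): combining the MLE objective with the fertility
# penalty at a 10:1 weighting. GlobalFertilityLoss has to come after MLELoss,
# since it reads the attention vectors populated during the forward pass.
#
#   composite = CompositeLoss(pt_losses=[MLELoss(), GlobalFertilityLoss()],
#                             loss_weight=[1.0, 0.1])
#   loss_expr = composite.calc_loss(model, src, trg)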

class ReinforceLoss(Serializable, LossCalculator):
  """
  Reinforce loss according to Ranzato+, 2015.
  Sequence Level Training with Recurrent Neural Networks.
  (This is not the MIXER algorithm.)
  https://arxiv.org/pdf/1511.06732.pdf
  """
  yaml_tag = '!ReinforceLoss'

  @serializable_init
  def __init__(self,
               baseline: Optional[Serializable] = None,
               evaluation_metric: metrics.SentenceLevelEvaluator = bare(metrics.FastBLEUEvaluator),
               search_strategy: search_strategies.SearchStrategy = bare(search_strategies.SamplingSearch),
               inv_eval: bool = True,
               decoder_hidden_dim: numbers.Integral = Ref("exp_global.default_layer_dim")) -> None:
    self.inv_eval = inv_eval
    self.search_strategy = search_strategy
    self.evaluation_metric = evaluation_metric
    self.baseline = self.add_serializable_component("baseline", baseline,
                                                    lambda: transforms.Linear(input_dim=decoder_hidden_dim,
                                                                              output_dim=1))

  def calc_loss(self,
                model: 'model_base.ConditionedModel',
                src: Union[sent.Sentence, 'batchers.Batch'],
                trg: Union[sent.Sentence, 'batchers.Batch']) -> losses.FactoredLossExpr:
    search_outputs = model.generate_search_output(src, self.search_strategy)
    sign = -1 if self.inv_eval else 1
    total_loss = losses.FactoredLossExpr()
    for search_output in search_outputs:
      # Calculate rewards
      eval_score = []
      for trg_i, sample_i in zip(trg, search_output.word_ids):
        # Remove EOS before scoring
        sample_i = self.remove_eos(sample_i.tolist())
        ref_i = trg_i.words[:trg_i.len_unpadded()]
        score = self.evaluation_metric.evaluate_one_sent(ref_i, sample_i)
        eval_score.append(sign * score)
      reward = dy.inputTensor(eval_score, batched=True)
      # Compose losses
      loss = losses.FactoredLossExpr()
      baseline_loss = []
      cur_losses = []
      for state, mask in zip(search_output.state, search_output.mask):
        bs_score = self.baseline.transform(dy.nobackprop(state.as_vector()))
        baseline_loss.append(dy.squared_distance(reward, bs_score))
        logsoft = model.decoder.scorer.calc_log_probs(state.as_vector())
        loss_i = dy.cmult(logsoft, reward - bs_score)
        cur_losses.append(dy.cmult(loss_i, dy.inputTensor(mask, batched=True)))
      loss.add_loss("reinforce", dy.sum_elems(dy.esum(cur_losses)))
      loss.add_loss("reinf_baseline", dy.sum_elems(dy.esum(baseline_loss)))
      # Accumulate into the total
      total_loss.add_factored_loss_expr(loss)
    return total_loss
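
# REINFORCE refresher (illustrative): for a sampled hypothesis y with reward
# r(y) and a learned baseline b(s), the per-step surrogate loss is
#
#     (r(y) - b(s)) * log p(y_t | s_t)
#
# mirroring the dy.cmult(logsoft, reward - bs_score) term above. The baseline
# is trained separately on (r(y) - b(s))^2 and is detached from the policy
# gradient via dy.nobackprop, so it reduces the variance of the gradient
# estimate without biasing the reward signal.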

class MinRiskLoss(Serializable, LossCalculator):
  yaml_tag = '!MinRiskLoss'

  @serializable_init
  def __init__(self,
               evaluation_metric: metrics.Evaluator = bare(metrics.FastBLEUEvaluator),
               alpha: numbers.Real = 0.005,
               inv_eval: bool = True,
               unique_sample: bool = True,
               search_strategy: search_strategies.SearchStrategy = bare(search_strategies.SamplingSearch)) -> None:
    self.alpha = alpha
    self.evaluation_metric = evaluation_metric
    self.inv_eval = inv_eval
    self.unique_sample = unique_sample
    self.search_strategy = search_strategy

  def calc_loss(self,
                model: 'model_base.ConditionedModel',
                src: Union[sent.Sentence, 'batchers.Batch'],
                trg: Union[sent.Sentence, 'batchers.Batch']) -> losses.FactoredLossExpr:
    batch_size = trg.batch_size()
    uniques = [set() for _ in range(batch_size)]
    deltas = []
    probs = []
    sign = -1 if self.inv_eval else 1
    search_outputs = model.generate_search_output(src, self.search_strategy)
    for search_output in search_outputs:
      assert len(search_output.word_ids) == 1
      assert search_output.word_ids[0].shape == (len(search_output.state),)
      logprob = []
      for word, state in zip(search_output.word_ids[0], search_output.state):
        lpdist = model.decoder.scorer.calc_log_probs(state.as_vector())
        lp = dy.pick(lpdist, word)
        logprob.append(lp)
      sample = search_output.word_ids
      logprob = dy.esum(logprob) * self.alpha
      # Calculate the evaluation score
      eval_score = np.zeros(batch_size, dtype=float)
      mask = np.zeros(batch_size, dtype=float)
      for j in range(batch_size):
        ref_j = self.remove_eos(trg[j].words)
        hyp_j = self.remove_eos(sample[j].tolist())
        if self.unique_sample:
          hash_val = hash(tuple(hyp_j))
          if len(hyp_j) == 0 or hash_val in uniques[j]:
            mask[j] = -1e20  # stands in for negative infinity
            continue
          else:
            uniques[j].add(hash_val)
        eval_score[j] = self.evaluation_metric.evaluate_one_sent(ref_j, hyp_j) * sign
      # Append the delta and logprob of this sample
      prob = logprob + dy.inputTensor(mask, batched=True)
      deltas.append(dy.inputTensor(eval_score, batched=True))
      probs.append(prob)
    sample_prob = dy.softmax(dy.concatenate(probs))
    deltas = dy.concatenate(deltas)
    risk = dy.sum_elems(dy.cmult(sample_prob, deltas))
    return losses.FactoredLossExpr({"risk": risk})
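
# Objective sketch (illustrative): this is minimum-risk training in the style
# of Shen+, 2016 ("Minimum Risk Training for Neural Machine Translation").
# Restricted to the sampled hypotheses y_1..y_k,
#
#     q(y_i|x) = softmax_i(alpha * log p(y_i|x))     # sharpened model distribution
#     risk     = sum_i q(y_i|x) * Delta(y_i)         # expected (negated) eval score
#
# The alpha temperature sharpens or flattens the distribution over samples;
# duplicate or empty samples receive a -1e20 logit so their q(y_i|x) is
# effectively zero after the softmax.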

class FeedbackLoss(Serializable, LossCalculator):
  """
  A loss that first calculates a standard loss function, then feeds it back to the
  model using the model.additional_loss function.

  Args:
    child_loss: The loss that will be fed back to the model
    repeat: Repeat the process multiple times and use the sum of the losses. This is
            useful when there is some non-determinism (such as sampling in the encoder, etc.)
  """
  yaml_tag = '!FeedbackLoss'

  @serializable_init
  def __init__(self,
               child_loss: LossCalculator = bare(MLELoss),
               repeat: numbers.Integral = 1) -> None:
    self.child_loss = child_loss
    self.repeat = repeat

  def calc_loss(self,
                model: 'model_base.ConditionedModel',
                src: Union[sent.Sentence, 'batchers.Batch'],
                trg: Union[sent.Sentence, 'batchers.Batch']) -> losses.FactoredLossExpr:
    loss_builder = losses.FactoredLossExpr()
    for _ in range(self.repeat):
      standard_loss = self.child_loss.calc_loss(model, src, trg)
      additional_loss = event_trigger.calc_additional_loss(trg, model, standard_loss)
      loss_builder.add_factored_loss_expr(standard_loss)
      loss_builder.add_factored_loss_expr(additional_loss)
    return loss_builder
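
# Usage sketch (illustrative): wrapping MLELoss so that auxiliary losses
# raised through event_trigger.calc_additional_loss (e.g. from a stochastic
# encoder) are accumulated over three passes.
#
#   fb_loss = FeedbackLoss(child_loss=MLELoss(), repeat=3)
#   loss_expr = fb_loss.calc_loss(model, src, trg)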