from typing import Sequence, Union, Optional, Any
import dynet as dy
from xnmt.settings import settings
from xnmt import batchers, event_trigger, events, inferences, input_readers, loss_calculators, losses, reports, utils, \
xnmt_evaluate
from xnmt.eval import metrics
from xnmt.models import base as model_base
from xnmt.persistence import serializable_init, Serializable, Ref, bare
class EvalTask(object):
"""
An EvalTask is a task that does evaluation and returns one or more EvalScore objects.
"""
def eval(self) -> 'metrics.EvalScore':
raise NotImplementedError("EvalTask.eval() needs to be implemented in child classes")
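
# Illustrative sketch (not part of xnmt): a custom EvalTask only needs to
# implement eval() and return an EvalScore. The subclass and file name below
# are hypothetical; LossScore is used merely as a convenient numeric EvalScore.
#
#   class WordCountEvalTask(EvalTask):
#     def __init__(self, hyp_file: str) -> None:
#       self.hyp_file = hyp_file
#     def eval(self) -> 'metrics.EvalScore':
#       with open(self.hyp_file) as f:
#         n_words = sum(len(line.split()) for line in f)
#       return metrics.LossScore(float(n_words), desc="word count")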
class LossEvalTask(EvalTask, Serializable):
"""
  A task that evaluates the model's loss function on a reference corpus.
Args:
src_file: source file name
ref_file: reference file name
model: generator model to use for inference
batcher: batcher to use
loss_calculator: loss calculator
    max_src_len: omit sentences with source length greater than the specified number
    max_trg_len: omit sentences with target length greater than the specified number
    max_num_sents: compute the loss only for the first n sentences of the given corpus
loss_comb_method: method for combining loss across batch elements ('sum' or 'avg').
desc: description to pass on to computed score objects
"""
yaml_tag = '!LossEvalTask'
@serializable_init
def __init__(self,
src_file: Union[str, Sequence[str]],
ref_file: Optional[str] = None,
model: 'model_base.GeneratorModel' = Ref("model"),
batcher: batchers.Batcher = Ref("train.batcher", default=bare(batchers.SrcBatcher, batch_size=32)),
loss_calculator: loss_calculators.LossCalculator = bare(loss_calculators.MLELoss),
max_src_len: Optional[int] = None,
max_trg_len: Optional[int] = None,
max_num_sents: Optional[int] = None,
loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
desc: Any = None) -> None:
self.model = model
self.loss_calculator = loss_calculator
self.src_file = src_file
self.ref_file = ref_file
self.batcher = batcher
self.src_data = None
self.max_src_len = max_src_len
self.max_trg_len = max_trg_len
self.max_num_sents = max_num_sents
self.loss_comb_method = loss_comb_method
    self.desc = desc
  def eval(self) -> 'metrics.EvalScore':
"""
    Perform the evaluation task.
Returns:
Evaluated score
"""
event_trigger.set_train(False)
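    # Read and batch the corpus lazily on the first call; later calls reuse it.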
if self.src_data is None:
self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
input_readers.read_parallel_corpus(src_reader=self.model.src_reader,
trg_reader=self.model.trg_reader,
src_file=self.src_file,
trg_file=self.ref_file,
batcher=self.batcher,
max_num_sents=self.max_num_sents,
max_src_len=self.max_src_len,
max_trg_len=self.max_trg_len)
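    # Accumulate the factored loss over all batches, then normalize by the
    # total number of (unpadded) reference words.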
loss_val = losses.FactoredLossVal()
ref_words_cnt = 0
for src, trg in zip(self.src_batches, self.ref_batches):
with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
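        # Start a fresh DyNet computation graph for each batch.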
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
loss = self.loss_calculator.calc_loss(self.model, src, trg)
ref_words_cnt += sum([trg_i.len_unpadded() for trg_i in trg])
loss_val += loss.get_factored_loss_val(comb_method=self.loss_comb_method)
loss_stats = {k: v/ref_words_cnt for k, v in loss_val.items()}
return metrics.LossScore(sum(loss_stats.values()),
loss_stats=loss_stats,
num_ref_words=ref_words_cnt,
desc=self.desc)
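
# Usage sketch (file names are placeholders; my_model stands for a trained
# model_base.GeneratorModel):
#
#   task = LossEvalTask(src_file="dev.src", ref_file="dev.trg", model=my_model)
#   score = task.eval()
#   print(score.value())  # per-reference-word loss, summed over loss factors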
class AccuracyEvalTask(EvalTask, Serializable):
"""
  A task that evaluates the model's output with one or more accuracy metrics.
Args:
src_file: path(s) to read source file(s) from
ref_file: path(s) to read reference file(s) from
hyp_file: path to write hypothesis file to
model: generator model to generate hypothesis with
eval_metrics: list of evaluation metrics (list of Evaluator objects or string of comma-separated shortcuts)
inference: inference object
    perform_inference: whether to run inference to generate the output. If False, the task evaluates an
                       already existing hyp_file, e.g. one written by a previous eval task.
desc: human-readable description passed on to resulting score objects
"""
yaml_tag = '!AccuracyEvalTask'
@serializable_init
@events.register_xnmt_handler
def __init__(self,
               src_file: Union[str, Sequence[str]],
               ref_file: Union[str, Sequence[str]],
hyp_file: str,
model: 'model_base.GeneratorModel' = Ref("model"),
eval_metrics: Union[str, metrics.Evaluator, Sequence[metrics.Evaluator]] = "bleu",
inference: Optional['inferences.Inference'] = None,
perform_inference: bool = True,
desc: Any = None) -> None:
self.model = model
    # Accept a comma-separated string of metric shortcuts, a single Evaluator, or a list of Evaluators.
    if isinstance(eval_metrics, str):
      eval_metrics = [xnmt_evaluate.eval_shortcuts[shortcut]() for shortcut in eval_metrics.split(",")]
    elif not isinstance(eval_metrics, Sequence):
      eval_metrics = [eval_metrics]
self.eval_metrics = eval_metrics
self.src_file = src_file
self.ref_file = ref_file
self.hyp_file = hyp_file
self.inference = inference or self.model.inference
self.perform_inference = perform_inference
self.desc = desc
  def eval(self) -> Sequence[metrics.EvalScore]:
    """
    Perform the evaluation, running inference first if requested.
    Returns:
      Evaluated scores, one per metric.
    """
event_trigger.set_train(False)
    if isinstance(self.model, reports.Reportable):
self.model.report_corpus_info({"ref_file": self.ref_file})
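    # Generate hypotheses unless an existing hyp_file is being reused.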
if self.perform_inference:
self.inference.perform_inference(generator=self.model,
src_file=self.src_file,
trg_file=self.hyp_file,
ref_file=self.ref_file)
# Evaluate
eval_scores = xnmt_evaluate.xnmt_evaluate(hyp_file=self.hyp_file, ref_file=self.ref_file, desc=self.desc,
evaluators=self.eval_metrics)
return eval_scores
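
# Configuration sketch: in a YAML experiment file this task would typically
# appear under the experiment's evaluate section (file names are placeholders):
#
#   evaluate:
#   - !AccuracyEvalTask
#     src_file: test.src
#     ref_file: test.trg
#     hyp_file: test.hyp
#     eval_metrics: bleu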
class DecodingEvalTask(EvalTask, Serializable):
"""
  A task that performs decoding without comparing against a reference.
Args:
src_file: path(s) to read source file(s) from
hyp_file: path to write hypothesis file to
model: generator model to generate hypothesis with
inference: inference object
"""
yaml_tag = '!DecodingEvalTask'
@serializable_init
def __init__(self,
               src_file: Union[str, Sequence[str]],
hyp_file: str,
model: 'model_base.GeneratorModel' = Ref("model"),
inference: Optional['inferences.Inference'] = None) -> None:
self.model = model
self.src_file = src_file
self.hyp_file = hyp_file
self.inference = inference or self.model.inference
def eval(self) -> None:
event_trigger.set_train(False)
self.inference.perform_inference(generator=self.model,
src_file=self.src_file,
trg_file=self.hyp_file)
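
# Usage sketch (paths are placeholders; my_model stands for a trained
# model_base.GeneratorModel):
#
#   task = DecodingEvalTask(src_file="test.src", hyp_file="test.hyp",
#                           model=my_model)
#   task.eval()  # writes hypotheses to test.hyp; no score is returned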