Source code for xnmt.train.regimens

import contextlib
from typing import Callable, Dict, Optional, Sequence, Union
from collections import OrderedDict
import numbers

from xnmt.settings import settings
import numpy as np
import dynet as dy


from xnmt import batchers, event_trigger, loss_calculators, loss_trackers, losses, optimizers, param_collections, utils
from xnmt.models import base as models
from xnmt.persistence import serializable_init, Serializable, bare, Ref
from xnmt.eval import tasks as eval_tasks
from xnmt.train import tasks as train_tasks


class TrainingRegimen(object):
  """
  A training regimen is a class that implements a training loop.
  """
  def run_training(self, save_fct: Callable) -> None:
    """
    Run training steps in a loop until the stopping criterion is reached.

    Args:
      save_fct: function to be invoked to save a model at dev checkpoints
    """
    raise NotImplementedError("")

  def backward(self, loss: dy.Expression, dynet_profiling: numbers.Integral) -> None:
    """
    Perform the backward pass to accumulate gradients.

    Args:
      loss: result of self.training_step(...)
      dynet_profiling: if > 0, print the computation graph
    """
    if dynet_profiling and dynet_profiling > 0:
      dy.print_text_graphviz()
    loss.backward()

  def update(self, trainer: optimizers.XnmtOptimizer) -> None:
    """
    Update DyNet weights using the given optimizer.

    Args:
      trainer: DyNet trainer
    """
    trainer.update()
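
# --- Illustrative sketch (not part of xnmt): the bare DyNet cycle that
# backward() and update() above wrap. The helper name and toy model are
# hypothetical; only standard DyNet calls are used, assuming DyNet 2.x.
def _demo_dynet_step() -> None:
  pc = dy.ParameterCollection()
  demo_trainer = dy.SimpleSGDTrainer(pc, learning_rate=0.1)
  W = pc.add_parameters((1, 2))
  dy.renew_cg()                                  # fresh computation graph
  x = dy.inputVector([1.0, 2.0])
  step_loss = dy.squared_norm(dy.parameter(W) * x)
  step_loss.value()                              # forward pass
  step_loss.backward()                           # accumulate gradients
  demo_trainer.update()                          # apply accumulated gradients
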
class SimpleTrainingRegimen(train_tasks.SimpleTrainingTask, TrainingRegimen, Serializable):
  """
  Args:
    model: the model
    src_file: the source training file
    trg_file: the target training file
    dev_every: dev checkpoints every n sentences (0 for only after epoch)
    dev_zero: if True, add a checkpoint before the training loop is entered (useful with pretrained models)
    batcher: type of batcher
    loss_calculator: the method for calculating the loss
    trainer: trainer object, default is SGD with learning rate 0.1
    run_for_epochs:
    lr_decay:
    lr_decay_times: early stopping after decaying the learning rate a certain number of times
    patience: apply LR decay after dev scores haven't improved over this many checkpoints
    initial_patience: if given, allows adjusting patience for the first LR decay
    dev_tasks: a list of tasks to use during the development stage
    dev_combinator: a formula to combine development scores into a single score used to decide whether to
                    perform learning rate decay, etc. E.g. 'x[0]-x[1]' means that the first dev task score
                    minus the second dev task score is our measure of how well we're doing. If not
                    specified, only the score from the first dev task will be used.
    restart_trainer: restart the trainer (useful for Adam) and revert weights to the best dev checkpoint
                     when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    reload_command: command to change the input data after each epoch. --epoch EPOCH_NUM will be appended
                    to the command. To just reload the data after each epoch, set the command to ``True``.
    name: will be prepended to log outputs if given
    sample_train_sents:
    max_num_train_sents:
    max_src_len:
    max_trg_len:
    loss_comb_method: method for combining loss across batch elements (``sum`` or ``avg``)
    update_every: simulate large-batch training by accumulating gradients over several steps before
                  updating parameters
    commandline_args:
  """
  yaml_tag = '!SimpleTrainingRegimen'

  @serializable_init
  def __init__(self,
               model: models.ConditionedModel = Ref("model"),
               src_file: Union[None, str, Sequence[str]] = None,
               trg_file: Optional[str] = None,
               dev_every: numbers.Integral = 0,
               dev_zero: bool = False,
               batcher: batchers.Batcher = bare(batchers.SrcBatcher, batch_size=32),
               loss_calculator: loss_calculators.LossCalculator = bare(loss_calculators.MLELoss),
               trainer: optimizers.XnmtOptimizer = bare(optimizers.SimpleSGDTrainer, e0=0.1),
               run_for_epochs: Optional[numbers.Integral] = None,
               lr_decay: numbers.Real = 1.0,
               lr_decay_times: numbers.Integral = 3,
               patience: numbers.Integral = 1,
               initial_patience: Optional[numbers.Integral] = None,
               dev_tasks: Sequence[eval_tasks.EvalTask] = None,
               dev_combinator: Optional[str] = None,
               restart_trainer: bool = False,
               reload_command: Optional[str] = None,
               name: str = "{EXP}",
               sample_train_sents: Optional[numbers.Integral] = None,
               max_num_train_sents: Optional[numbers.Integral] = None,
               max_src_len: Optional[numbers.Integral] = None,
               max_trg_len: Optional[numbers.Integral] = None,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
               update_every: numbers.Integral = 1,
               commandline_args: dict = Ref("exp_global.commandline_args", default={})) -> None:
    super().__init__(model=model,
                     src_file=src_file,
                     trg_file=trg_file,
                     dev_every=dev_every,
                     batcher=batcher,
                     loss_calculator=loss_calculator,
                     run_for_epochs=run_for_epochs,
                     lr_decay=lr_decay,
                     lr_decay_times=lr_decay_times,
                     patience=patience,
                     initial_patience=initial_patience,
                     dev_tasks=dev_tasks,
                     dev_combinator=dev_combinator,
                     restart_trainer=restart_trainer,
                     reload_command=reload_command,
                     name=name,
                     sample_train_sents=sample_train_sents,
                     max_num_train_sents=max_num_train_sents,
                     max_src_len=max_src_len,
                     max_trg_len=max_trg_len)
    self.dev_zero = dev_zero
    self.trainer = trainer or optimizers.SimpleSGDTrainer(e0=0.1)
    self.dynet_profiling = commandline_args.get("dynet_profiling", 0) if commandline_args else 0
    self.train_loss_tracker = loss_trackers.TrainLossTracker(self)
    self.loss_comb_method = loss_comb_method
    self.update_every = update_every
    self.num_updates_skipped = 0

  def run_training(self, save_fct: Callable) -> None:
    """
    Main training loop (overrides TrainingRegimen.run_training())
    """
    if self.run_for_epochs is None or self.run_for_epochs > 0:
      for src, trg in self.next_minibatch():
        if self.dev_zero:
          self.checkpoint_and_save(save_fct)
          self.dev_zero = False
        with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
          dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
          with self.train_loss_tracker.time_tracker:
            event_trigger.set_train(True)
            loss_builder = self.training_step(src, trg)
            loss = loss_builder.compute()
            self.backward(loss, self.dynet_profiling)
            self.update(self.trainer)
          self.train_loss_tracker.report(trg, loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method))
        if self.checkpoint_needed():
          self.checkpoint_and_save(save_fct)
        if self.should_stop_training():
          break

  def checkpoint_and_save(self, save_fct: Callable) -> None:
    should_save = self.checkpoint()
    if should_save:
      save_fct()

  def update(self, trainer: optimizers.XnmtOptimizer) -> None:
    self.num_updates_skipped += 1
    if self.num_updates_skipped == self.update_every:
      trainer.update()
      self.num_updates_skipped = 0
    else:
      assert 0 < self.num_updates_skipped < self.update_every
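
# --- Illustrative sketch (not part of xnmt): the update_every counter used by
# SimpleTrainingRegimen.update() above, in isolation. With update_every=3, a
# real trainer.update() fires on every third call, simulating a batch three
# times larger. `_DemoAccumulator` is a hypothetical name.
class _DemoAccumulator:
  def __init__(self, update_every: int) -> None:
    self.update_every = update_every
    self.num_updates_skipped = 0

  def step(self) -> bool:
    """Return True iff this call would trigger a real parameter update."""
    self.num_updates_skipped += 1
    if self.num_updates_skipped == self.update_every:
      self.num_updates_skipped = 0
      return True
    return False

# _acc = _DemoAccumulator(update_every=3)
# [_acc.step() for _ in range(6)]  # -> [False, False, True, False, False, True]
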
class AutobatchTrainingRegimen(SimpleTrainingRegimen):
  """
  This regimen overrides SimpleTrainingRegimen by accumulating (summing) losses into a FactoredLossExpr
  *before* running forward/backward on the computation graph. It is designed to work with DyNet autobatching
  in cases where parts of the architecture make batching difficult (such as structured encoders like
  TreeLSTMs or graph networks). The actual batch size is set through the "update_every" parameter, while the
  underlying batcher is expected to have "batch_size" equal to 1.

  Args:
    model: the model
    src_file: the source training file
    trg_file: the target training file
    dev_every: dev checkpoints every n sentences (0 for only after epoch)
    dev_zero: if True, add a checkpoint before the training loop is entered (useful with pretrained models)
    batcher: type of batcher
    loss_calculator: the method for calculating the loss
    trainer: trainer object, default is SGD with learning rate 0.1
    run_for_epochs:
    lr_decay:
    lr_decay_times: early stopping after decaying the learning rate a certain number of times
    patience: apply LR decay after dev scores haven't improved over this many checkpoints
    initial_patience: if given, allows adjusting patience for the first LR decay
    dev_tasks: a list of tasks to use during the development stage
    dev_combinator: a formula to combine development scores into a single score used to decide whether to
                    perform learning rate decay, etc. E.g. 'x[0]-x[1]' means that the first dev task score
                    minus the second dev task score is our measure of how well we're doing. If not
                    specified, only the score from the first dev task will be used.
    restart_trainer: restart the trainer (useful for Adam) and revert weights to the best dev checkpoint
                     when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
    reload_command: command to change the input data after each epoch. --epoch EPOCH_NUM will be appended
                    to the command. To just reload the data after each epoch, set the command to ``True``.
    name: will be prepended to log outputs if given
    sample_train_sents:
    max_num_train_sents:
    max_src_len:
    max_trg_len:
    loss_comb_method: method for combining loss across batch elements (``sum`` or ``avg``)
    update_every: how many instances to accumulate before updating parameters. This effectively sets the
                  batch size under DyNet autobatching.
    commandline_args:
  """
  yaml_tag = '!AutobatchTrainingRegimen'

  @serializable_init
  def __init__(self,
               model: models.ConditionedModel = Ref("model"),
               src_file: Union[None, str, Sequence[str]] = None,
               trg_file: Optional[str] = None,
               dev_every: numbers.Integral = 0,
               dev_zero: bool = False,
               batcher: batchers.Batcher = bare(batchers.SrcBatcher, batch_size=32),
               loss_calculator: loss_calculators.LossCalculator = bare(loss_calculators.MLELoss),
               trainer: optimizers.XnmtOptimizer = bare(optimizers.SimpleSGDTrainer, e0=0.1),
               run_for_epochs: Optional[numbers.Integral] = None,
               lr_decay: numbers.Real = 1.0,
               lr_decay_times: numbers.Integral = 3,
               patience: numbers.Integral = 1,
               initial_patience: Optional[numbers.Integral] = None,
               dev_tasks: Sequence[eval_tasks.EvalTask] = None,
               dev_combinator: Optional[str] = None,
               restart_trainer: bool = False,
               reload_command: Optional[str] = None,
               name: str = "{EXP}",
               sample_train_sents: Optional[numbers.Integral] = None,
               max_num_train_sents: Optional[numbers.Integral] = None,
               max_src_len: Optional[numbers.Integral] = None,
               max_trg_len: Optional[numbers.Integral] = None,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
               update_every: numbers.Integral = 1,
               commandline_args: dict = Ref("exp_global.commandline_args", default={})) -> None:
    super().__init__(model=model,
                     src_file=src_file,
                     trg_file=trg_file,
                     dev_every=dev_every,
                     batcher=batcher,
                     loss_calculator=loss_calculator,
                     run_for_epochs=run_for_epochs,
                     lr_decay=lr_decay,
                     lr_decay_times=lr_decay_times,
                     patience=patience,
                     initial_patience=initial_patience,
                     dev_tasks=dev_tasks,
                     dev_combinator=dev_combinator,
                     restart_trainer=restart_trainer,
                     reload_command=reload_command,
                     name=name,
                     sample_train_sents=sample_train_sents,
                     max_num_train_sents=max_num_train_sents,
                     max_src_len=max_src_len,
                     max_trg_len=max_trg_len)
    if batcher.batch_size != 1:
      raise ValueError("AutobatchTrainingRegimen forces the batcher to have batch_size 1. "
                       "Use update_every to set the actual batch size in this regimen.")
    self.dev_zero = dev_zero
    self.trainer = trainer or optimizers.SimpleSGDTrainer(e0=0.1)
    self.dynet_profiling = commandline_args.get("dynet_profiling", 0) if commandline_args else 0
    self.train_loss_tracker = loss_trackers.TrainLossTracker(self)
    self.loss_comb_method = loss_comb_method
    self.update_every = update_every
    self.num_updates_skipped = 0

  def run_training(self, save_fct: Callable) -> None:
    """
    Main training loop (overrides TrainingRegimen.run_training())
    """
    dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
    if self.run_for_epochs is None or self.run_for_epochs > 0:
      total_loss = losses.FactoredLossExpr()
      # needed for report
      total_trg = []
      for src, trg in self.next_minibatch():
        if self.dev_zero:
          self.checkpoint_and_save(save_fct)
          self.dev_zero = False
        with utils.ReportOnException({"src": src, "trg": trg, "graph": utils.print_cg_conditional}):
          with self.train_loss_tracker.time_tracker:
            event_trigger.set_train(True)
            total_trg.append(trg[0])
            loss_builder = self.training_step(src, trg)
            total_loss.add_factored_loss_expr(loss_builder)
            # num_updates_skipped is incremented in update(), but we need to
            # call backward() before update()
            if self.num_updates_skipped == self.update_every - 1:
              self.backward(total_loss.compute(), self.dynet_profiling)
              self.update(self.trainer)
            if self.num_updates_skipped == 0:
              total_loss_val = total_loss.get_factored_loss_val(comb_method=self.loss_comb_method)
              reported_trg = batchers.ListBatch(total_trg)
              self.train_loss_tracker.report(reported_trg, total_loss_val)
              total_loss = losses.FactoredLossExpr()
              total_trg = []
              dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
        if self.checkpoint_needed():
          # Do a last update before the checkpoint: force forward-backward for
          # the last batch even if it is smaller than update_every
          self.num_updates_skipped = self.update_every - 1
          self.backward(total_loss.compute(), self.dynet_profiling)
          self.update(self.trainer)
          total_loss_val = total_loss.get_factored_loss_val(comb_method=self.loss_comb_method)
          reported_trg = batchers.ListBatch(total_trg)
          self.train_loss_tracker.report(reported_trg, total_loss_val)
          total_loss = losses.FactoredLossExpr()
          total_trg = []
          self.checkpoint_and_save(save_fct)
        if self.should_stop_training():
          break
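
# --- Illustrative sketch (not part of xnmt): why AutobatchTrainingRegimen sums
# losses before a single backward(). Under DyNet autobatching (e.g. started
# with --dynet-autobatch 1), building several per-sentence losses in one graph
# and running one forward/backward over their sum lets DyNet batch identical
# operations automatically. `compute_sentence_loss` is a hypothetical callback
# returning a scalar dy.Expression.
def _demo_autobatch_accumulation(compute_sentence_loss, sentences, trainer) -> None:
  dy.renew_cg()
  per_sentence = [compute_sentence_loss(s) for s in sentences]  # graph grows; no forward yet
  total = dy.esum(per_sentence)  # one scalar expression over the simulated batch
  total.backward()               # single backward pass; ops get auto-batched
  trainer.update()
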
class MultiTaskTrainingRegimen(TrainingRegimen):
  """
  Base class for multi-task training classes. Mainly initializes tasks, performs sanity checks, and manages
  set_train events.

  Args:
    tasks: list of training tasks. The first item takes on the role of the main task, meaning it will
           control early stopping, the learning rate schedule, and model checkpoints.
    trainer: trainer object, default is SGD with learning rate 0.1
    dev_zero: if True, add a checkpoint before the training loop is entered (useful with pretrained models)
    update_every: simulate large-batch training by accumulating gradients over several steps before
                  updating parameters
    commandline_args:
  """
  def __init__(self,
               tasks: Sequence[train_tasks.TrainingTask],
               trainer: optimizers.XnmtOptimizer = bare(optimizers.SimpleSGDTrainer, e0=0.1),
               dev_zero: bool = False,
               update_every: numbers.Integral = 1,
               commandline_args: dict = Ref("exp_global.commandline_args", default=None)) -> None:
    super().__init__()
    self.dynet_profiling = commandline_args.get("dynet_profiling", 0) if commandline_args else 0
    if len(tasks) == 0:
      raise ValueError("Task list must be non-empty.")
    self.tasks = tasks
    self.trainer = trainer
    for task in tasks[1:]:
      if hasattr(task, "trainer") and task.trainer is not None:
        raise ValueError("Can instantiate only one trainer object. Possibly, multiple training regimens "
                         "were created when training tasks should have been used.")
    self.train = None
    self.model_file = param_collections.ParamManager.param_col.model_file
    for task in tasks:
      task.trainer = trainer
    self.dev_zero = dev_zero
    self.update_every = update_every
    self.num_updates_skipped = 0

  def trigger_train_event(self, value: bool) -> None:
    """
    Trigger a set_train event, but only if that would change the current value of set_train.

    Args:
      value: True or False
    """
    if self.train is None:
      self.train = value
      event_trigger.set_train(value)
    else:
      if value != self.train:
        self.train = value
        event_trigger.set_train(value)

  def update(self, trainer: optimizers.XnmtOptimizer) -> None:
    self.num_updates_skipped += 1
    if self.num_updates_skipped == self.update_every:
      trainer.update()
      self.num_updates_skipped = 0
    else:
      assert 0 < self.num_updates_skipped < self.update_every
class SameBatchMultiTaskTrainingRegimen(MultiTaskTrainingRegimen, Serializable):
  """
  Multi-task training where gradients are accumulated and weight updates are thus performed jointly for
  each task. The relative weight between tasks can be configured by setting the number of steps to
  accumulate over for each task. Note that the batch size for each task also has an influence on task
  weighting.

  The stopping criterion of the first task is used (other tasks' stopping criteria are ignored).

  Args:
    tasks: training tasks
    trainer: the trainer is shared across tasks
    dev_zero: if ``True``, add a checkpoint before the training loop is entered (useful with pretrained
              models)
    per_task_backward: if ``True``, call backward() for each task separately and renew the computation
                       graph between tasks. Yields the same results, but ``True`` uses less memory while
                       ``False`` may be faster when using autobatching.
    loss_comb_method: method for combining loss across batch elements ('sum' or 'avg')
    update_every: simulate large-batch training by accumulating gradients over several steps before
                  updating parameters. This is implemented as an outer loop: we first accumulate gradients
                  from the steps for each task, then repeat according to this parameter, so that multiple
                  steps are collected for each task, always in the same ratio.
    n_task_steps: the number of steps to accumulate for each task, useful for weighting tasks
    commandline_args:
  """
  yaml_tag = "!SameBatchMultiTaskTrainingRegimen"

  @serializable_init
  def __init__(self,
               tasks: Sequence[train_tasks.TrainingTask],
               trainer: optimizers.XnmtOptimizer = bare(optimizers.SimpleSGDTrainer, e0=0.1),
               dev_zero: bool = False,
               per_task_backward: bool = True,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
               update_every: numbers.Integral = 1,
               n_task_steps: Optional[Sequence[numbers.Integral]] = None,
               commandline_args: dict = Ref("exp_global.commandline_args", default=None)) -> None:
    super().__init__(tasks=tasks, trainer=trainer, dev_zero=dev_zero, update_every=update_every,
                     commandline_args=commandline_args)
    self.train_loss_trackers = {task: loss_trackers.TrainLossTracker(task) for task in tasks}
    self.per_task_backward = per_task_backward
    self.loss_comb_method = loss_comb_method
    self.n_task_steps = n_task_steps or [1] * len(tasks)
    if len(self.n_task_steps) != len(tasks):
      raise ValueError(f"number of tasks and steps per task do not match: "
                       f"{len(tasks)} != {len(self.n_task_steps)}")

  def run_training(self, save_fct: Callable) -> None:
    task_generators = OrderedDict()
    for task in self.tasks:
      task_generators[task] = task.next_minibatch()
    if self.tasks[0].run_for_epochs > 0:
      while True:
        task_losses = []
        task_src_trg = []
        for (task, task_gen), task_n in zip(task_generators.items(), self.n_task_steps):
          for _ in range(task_n):
            src, trg = next(task_gen)
            task_src_trg.append((task, src, trg))
        if self.dev_zero:  # True only in the first iteration
          self.checkpoint_and_save(save_fct)
        dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
        task_trg_loss_stats = {}
        # use an exit stack to control whether to use global or per-task time tracking
        with contextlib.ExitStack() as stack:
          if not self.per_task_backward:
            stack.enter_context(self.train_loss_trackers[self.tasks[0]].time_tracker)
          self.trigger_train_event(True)
          for task, src, trg in task_src_trg:
            with contextlib.ExitStack() as stack2:
              if self.per_task_backward:
                stack2.enter_context(self.train_loss_trackers[task].time_tracker)
              loss_builder = task.training_step(src, trg)
              task_trg_loss_stats[task] = (trg,
                                           loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method))
              if self.per_task_backward:
                self.backward(loss_builder.compute(), self.dynet_profiling)
                dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE,
                            check_validity=settings.CHECK_VALIDITY)
              else:
                task_losses.append(loss_builder.compute())
          if not self.per_task_backward:
            self.backward(sum(task_losses), self.dynet_profiling)
          self.update(self.trainer)
        for task, (trg, stats) in task_trg_loss_stats.items():
          self.train_loss_trackers[task].report(trg, stats)
        self.checkpoint_and_save(save_fct)
        if self.tasks[0].should_stop_training():
          break

  def checkpoint_and_save(self, save_fct: Callable) -> None:
    for task_i, task in enumerate(self.tasks):
      if self.dev_zero or task.checkpoint_needed():
        should_save = task.checkpoint(control_learning_schedule=(task_i == 0))
        if should_save:
          save_fct()
    self.dev_zero = False
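
# --- Illustrative sketch (not part of xnmt): how n_task_steps and batch sizes
# jointly determine per-update task weighting in the regimen above. The helper
# is hypothetical; it just computes the fraction of training examples each
# task contributes to one joint update.
def _demo_task_step_ratio(n_task_steps=(2, 1), batch_sizes=(16, 32)):
  examples_per_task = [n * b for n, b in zip(n_task_steps, batch_sizes)]
  total = sum(examples_per_task)
  return [e / total for e in examples_per_task]

# _demo_task_step_ratio((2, 1), (16, 32)) -> [0.5, 0.5]: the 2:1 step ratio is
# cancelled out by the second task's batches being twice as large.
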
class AlternatingBatchMultiTaskTrainingRegimen(MultiTaskTrainingRegimen, Serializable):
  """
  Multi-task training where training steps are performed one after another. The relative weights between
  tasks are specified explicitly, and for each step one task is drawn at random accordingly.

  The stopping criterion of the first task is used (other tasks' stopping criteria are ignored).

  Args:
    tasks: training tasks
    task_weights: relative sampling weight of each task; defaults to weighting all tasks equally
    trainer: the trainer is shared across tasks
    dev_zero: if True, add a checkpoint before the training loop is entered (useful with pretrained models)
    loss_comb_method: method for combining loss across batch elements ('sum' or 'avg')
    update_every_within: simulate large-batch training by accumulating gradients over several steps before
                         updating parameters. Here, multiple draws are made from the same task until update
                         is invoked.
    update_every_across: simulate large-batch training by accumulating gradients over several steps before
                         updating parameters. Here, tasks are drawn randomly several times before
                         parameters are updated.
    commandline_args:
  """
  yaml_tag = "!AlternatingBatchMultiTaskTrainingRegimen"

  @serializable_init
  def __init__(self,
               tasks: Sequence[train_tasks.TrainingTask],
               task_weights: Optional[Sequence[numbers.Real]] = None,
               trainer: optimizers.XnmtOptimizer = bare(optimizers.SimpleSGDTrainer, e0=0.1),
               dev_zero: bool = False,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
               update_every_within: numbers.Integral = 1,
               update_every_across: numbers.Integral = 1,
               commandline_args=Ref("exp_global.commandline_args", default=None)) -> None:
    super().__init__(tasks=tasks, trainer=trainer, dev_zero=dev_zero, update_every=update_every_across,
                     commandline_args=commandline_args)
    if update_every_within != 1 and update_every_across != 1:
      raise ValueError("update_every_within and update_every_across cannot be mixed.")
    self.update_every_within = update_every_within
    self.task_weights = task_weights or [1. / len(tasks)] * len(tasks)
    if len(self.task_weights) != len(self.tasks):
      raise ValueError(f"number of tasks must match number of task weights; "
                       f"found: {len(self.task_weights)} != {len(self.tasks)}")
    self.train_loss_trackers = {task: loss_trackers.TrainLossTracker(task) for task in tasks}
    self.loss_comb_method = loss_comb_method

  def run_training(self, save_fct: Callable) -> None:
    task_generators = OrderedDict()
    for task in self.tasks:
      task_generators[task] = task.next_minibatch()
    dev_zero = {i: self.dev_zero for i in range(len(self.tasks))}
    if self.tasks[0].run_for_epochs > 0:
      while True:
        dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
        cur_task_i = np.random.choice(range(len(self.tasks)), p=self.task_weights)
        cur_task = self.tasks[cur_task_i]
        task_gen = task_generators[cur_task]
        if dev_zero[cur_task_i]:
          self.checkpoint_and_save(cur_task, cur_task_i, save_fct, dev_zero)
        cur_train_loss_tracker = self.train_loss_trackers[cur_task]
        with cur_train_loss_tracker.time_tracker:
          for _ in range(self.update_every_within):
            src, trg = next(task_gen)
            self.trigger_train_event(True)
            loss_builder = cur_task.training_step(src, trg)
            self.backward(loss=loss_builder.compute(), dynet_profiling=self.dynet_profiling)
            self.update(trainer=self.trainer)
        cur_train_loss_tracker.report(trg, loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method))
        self.checkpoint_and_save(cur_task, cur_task_i, save_fct, dev_zero)
        if self.tasks[0].should_stop_training():
          break

  def checkpoint_and_save(self,
                          cur_task: train_tasks.TrainingTask,
                          cur_task_i: numbers.Integral,
                          save_fct: Callable,
                          dev_zero: Dict[numbers.Integral, bool]) -> None:
    if dev_zero[cur_task_i] or cur_task.checkpoint_needed():
      dev_zero[cur_task_i] = False
      should_save = cur_task.checkpoint(control_learning_schedule=(cur_task_i == 0))
      if should_save:
        save_fct()
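
# --- Illustrative sketch (not part of xnmt): the weighted task sampling used
# in run_training() above, checked empirically. `_demo_task_sampling` is a
# hypothetical helper; np.random.choice with p= is exactly the call used above.
def _demo_task_sampling(task_weights=(0.7, 0.3), n_steps=10000):
  draws = np.random.choice(len(task_weights), size=n_steps, p=task_weights)
  # empirical task frequencies; these approach task_weights as n_steps grows
  return np.bincount(draws, minlength=len(task_weights)) / n_steps
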
class SerialMultiTaskTrainingRegimen(MultiTaskTrainingRegimen, Serializable):
  """
  Trains the first task until its stopping criterion is met, then the second task, and so on. Useful for
  realizing a pretraining-finetuning strategy.

  Args:
    tasks: training tasks. The currently active task is treated as the main task.
    trainer: the trainer is shared across tasks
    dev_zero: if True, add a checkpoint before the training loop is entered (useful with pretrained models)
    loss_comb_method: method for combining loss across batch elements ('sum' or 'avg')
    update_every: simulate large-batch training by accumulating gradients over several steps before
                  updating parameters
    commandline_args:
  """
  yaml_tag = "!SerialMultiTaskTrainingRegimen"

  @serializable_init
  def __init__(self,
               tasks: Sequence[train_tasks.TrainingTask],
               trainer: optimizers.XnmtOptimizer = bare(optimizers.SimpleSGDTrainer, e0=0.1),
               dev_zero: bool = False,
               loss_comb_method: str = Ref("exp_global.loss_comb_method", default="sum"),
               update_every: numbers.Integral = 1,
               commandline_args: dict = Ref("exp_global.commandline_args", default=None)) -> None:
    super().__init__(tasks=tasks, trainer=trainer, dev_zero=dev_zero, commandline_args=commandline_args,
                     update_every=update_every)
    self.train_loss_trackers = {task: loss_trackers.TrainLossTracker(task) for task in tasks}
    self.loss_comb_method = loss_comb_method

  def run_training(self, save_fct: Callable) -> None:
    dev_zero = {i: self.dev_zero for i in range(len(self.tasks))}
    for cur_task_id in range(len(self.tasks)):
      self.train = None
      cur_task = self.tasks[cur_task_id]
      cur_train_loss_tracker = self.train_loss_trackers[cur_task]
      task_gen = cur_task.next_minibatch()
      if cur_task.run_for_epochs > 0:
        while True:
          dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
          src, trg = next(task_gen)
          if dev_zero[cur_task_id]:
            self.checkpoint_and_save(cur_task, cur_task_id, save_fct, dev_zero)
          with cur_train_loss_tracker.time_tracker:
            self.trigger_train_event(True)
            loss_builder = cur_task.training_step(src, trg)
            task_loss = loss_builder.compute()
            self.backward(task_loss, self.dynet_profiling)
            self.update(self.trainer)
          cur_train_loss_tracker.report(trg, loss_builder.get_factored_loss_val(comb_method=self.loss_comb_method))
          self.checkpoint_and_save(cur_task, cur_task_id, save_fct, dev_zero)
          if cur_task.should_stop_training():
            break

  def checkpoint_and_save(self,
                          cur_task: train_tasks.TrainingTask,
                          cur_task_id: numbers.Integral,
                          save_fct: Callable,
                          dev_zero: Dict[numbers.Integral, bool]) -> None:
    if dev_zero[cur_task_id] or cur_task.checkpoint_needed():
      dev_zero[cur_task_id] = False
      should_save = cur_task.checkpoint(control_learning_schedule=True)
      if should_save:
        save_fct()
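
# --- Illustrative note (not part of xnmt): the pretrain-finetune schedule that
# SerialMultiTaskTrainingRegimen realizes. With hypothetical tasks
#
#   tasks: [pretrain_task, finetune_task]
#
# pretrain_task runs until its own should_stop_training() fires, after which
# finetune_task continues from the resulting weights; no interleaving occurs.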