from subprocess import Popen
from asteval import Interpreter
import random
from typing import Iterator, Optional, Sequence, Union
import numbers
import numpy as np
from xnmt import batchers, event_trigger, input_readers, logger, losses, loss_trackers, loss_calculators, \
param_collections
from xnmt.models import base as model_base
from xnmt.eval import tasks as eval_tasks
from xnmt.persistence import serializable_init, Serializable, bare
class TrainingTask(object):
    """
    Base class for a training task. Training tasks can perform training steps
    and keep track of the training state, but may not implement the actual training
    loop.

    Every method other than ``__init__`` is abstract: the base implementations raise
    ``NotImplementedError`` and must be overridden by subclasses.

    Args:
      model: The model to train
    """

    def __init__(self, model: 'model_base.TrainableModel') -> None:
        self.model = model

    def should_stop_training(self) -> bool:
        """
        Returns:
          True iff training is finished, i.e. training_step(...) should not be called again
        """
        raise NotImplementedError("must be implemented by subclasses")

    def training_step(self, **kwargs) -> 'losses.FactoredLossExpr':
        """
        Perform forward pass for the next training step and handle training logic (switching epoch, reshuffling, ..)

        Args:
          **kwargs: depends on subclass implementations
        Returns:
          Loss
        """
        raise NotImplementedError("must be implemented by subclasses")

    def next_minibatch(self) -> Iterator:
        """
        Infinitely loop over training minibatches.

        Returns:
          Generator yielding (src_batch,trg_batch) tuples
        """
        # Previously this method had no body and silently returned None, which only
        # surfaced later as an obscure "not iterable" error in the caller.
        raise NotImplementedError("must be implemented by subclasses")

    def checkpoint_needed(self) -> bool:
        """
        Returns:
          ``True`` iff a dev checkpoint (see ``checkpoint``) should be performed now
        """
        raise NotImplementedError("must be implemented by subclasses")

    def checkpoint(self, control_learning_schedule: bool = False) -> bool:
        """
        Perform a dev checkpoint.

        Args:
          control_learning_schedule: If ``False``, only evaluate dev data.
                                     If ``True``, also perform model saving, LR decay etc. if needed.
        Returns:
          ``True`` iff the model needs saving
        """
        raise NotImplementedError("must be implemented by subclasses")

    def cur_num_minibatches(self) -> int:
        """
        Current number of minibatches (may change between epochs, e.g. for randomizing batchers or if reload_command is given)
        """
        raise NotImplementedError("must be implemented by subclasses")

    def cur_num_sentences(self) -> int:
        """
        Current number of parallel sentences (may change between epochs, e.g. if reload_command is given)
        """
        raise NotImplementedError("must be implemented by subclasses")
class SimpleTrainingTask(TrainingTask, Serializable):
    """
    Supervised training task over a parallel corpus, with dev checkpoints, learning
    rate decay and early stopping.

    Args:
      model: a trainable supervised model
      src_file: The file for the source data.
      trg_file: The file for the target data.
      dev_every: dev checkpoints every n sentences (0 for only after epoch)
      batcher: Type of batcher
      loss_calculator: computes the loss of a (src, trg) minibatch in ``training_step``
      run_for_epochs: number of epochs (None for unlimited epochs)
      lr_decay: decay learning rate by multiplying by this factor; must satisfy
                0.0 < lr_decay <= 1.0 (1.0 disables LR decay and therefore also early stopping)
      lr_decay_times: Early stopping after decaying learning rate a certain number of times
      patience: apply LR decay after dev scores haven't improved over this many checkpoints
      initial_patience: if given, allows adjusting patience for the first LR decay
      dev_tasks: A list of tasks to run on the development set
      dev_combinator: A formula to combine together development scores into a single score to
                      choose whether to perform learning rate decay, etc.
                      e.g. 'x[0]-x[1]' would say that the first dev task score minus the
                      second dev task score is our measure of how good we're doing. If not
                      specified, only the score from the first dev task will be used.
      restart_trainer: Restart trainer (useful for Adam) and revert weights to best dev checkpoint when applying LR decay (https://arxiv.org/pdf/1706.09733.pdf)
      reload_command: Command to change the input data after each epoch.
                      --epoch EPOCH_NUM will be appended to the command.
                      To just reload the data after each epoch set the command to 'true'.
      sample_train_sents: If given, load a random subset of training sentences before each epoch. Useful when training data does not fit in memory.
      max_num_train_sents: Train only on the first n sentences
      max_src_len: Discard training sentences with source-side longer than this
      max_trg_len: Discard training sentences with target-side longer than this
      name: will be prepended to log outputs if given

    Raises:
      RuntimeError: if ``lr_decay`` is not in (0.0, 1.0]

    Note:
      ``checkpoint`` uses ``self.trainer`` (its ``learning_rate`` and ``restart()``) when
      decaying the learning rate. That attribute is never set in this class.
      NOTE(review): presumably it is provided by a subclass or the caller; confirm.
    """
    yaml_tag = '!SimpleTrainingTask'

    @serializable_init
    def __init__(self,
                 model: 'model_base.ConditionedModel',
                 src_file: Union[str, Sequence[str]] = None,
                 trg_file: str = None,
                 dev_every: numbers.Integral = 0,
                 batcher: batchers.Batcher = bare(batchers.SrcBatcher, batch_size=32),
                 loss_calculator: loss_calculators.LossCalculator = bare(loss_calculators.MLELoss),
                 run_for_epochs: Optional[numbers.Integral] = None,
                 lr_decay: numbers.Real = 1.0,
                 lr_decay_times: numbers.Integral = 3,
                 patience: numbers.Integral = 1,
                 initial_patience: Optional[numbers.Integral] = None,
                 dev_tasks: Sequence['eval_tasks.EvalTask'] = None,
                 dev_combinator=None,
                 restart_trainer: bool = False,
                 reload_command: Optional[str] = None,
                 name: Optional[str] = None,
                 sample_train_sents: Optional[numbers.Integral] = None,
                 max_num_train_sents: Optional[numbers.Integral] = None,
                 max_src_len: Optional[numbers.Integral] = None,
                 max_trg_len: Optional[numbers.Integral] = None) -> None:
        self.src_file = src_file
        self.trg_file = trg_file
        self.dev_tasks = dev_tasks
        self.dev_combinator = dev_combinator
        # Written as a positive range check so that NaN is rejected as well.
        if not 0.0 < lr_decay <= 1.0:
            raise RuntimeError("illegal lr_decay, must satisfy: 0.0 < lr_decay <= 1.0")
        self.lr_decay = lr_decay
        self.patience = patience
        self.initial_patience = initial_patience
        self.lr_decay_times = lr_decay_times
        self.restart_trainer = restart_trainer
        self.run_for_epochs = run_for_epochs
        self.early_stopping_reached = False
        # training state
        self.training_state = TrainingState()
        self.reload_command = reload_command
        # Handle of the most recently started reload_command subprocess (None = not started).
        self._augmentation_handle = None
        self.model = model
        self.loss_calculator = loss_calculator
        self.sample_train_sents = sample_train_sents
        self.max_num_train_sents = max_num_train_sents
        self.max_src_len = max_src_len
        self.max_trg_len = max_trg_len
        self.batcher = batcher
        self.dev_loss_tracker = loss_trackers.DevLossTracker(self, dev_every, name)
        self.name = name

    def _read_training_corpus(self) -> None:
        """
        (Re-)read the parallel training corpus and its initial batches.

        Sets ``src_data``, ``trg_data``, ``src_batches`` and ``trg_batches``.
        Afterwards it sets ``train = False`` on both of the model's readers.
        """
        self.src_data, self.trg_data, self.src_batches, self.trg_batches = \
            input_readers.read_parallel_corpus(src_reader=self.model.src_reader,
                                               trg_reader=self.model.trg_reader,
                                               src_file=self.src_file,
                                               trg_file=self.trg_file,
                                               batcher=self.batcher,
                                               sample_sents=self.sample_train_sents,
                                               max_num_sents=self.max_num_train_sents,
                                               max_src_len=self.max_src_len,
                                               max_trg_len=self.max_trg_len)
        self.model.src_reader.train = self.model.trg_reader.train = False

    def _augment_data_initial(self):
        """
        Called before loading corpus for the first time, if reload_command is given.

        Runs ``reload_command --epoch 0`` and blocks until it finishes, unless an
        augmentation process has already been started.
        """
        augment_command = self.reload_command
        logger.debug('initial augmentation')
        if self._augmentation_handle is None:
            # first run
            # shell=True because reload_command is documented as a shell command string;
            # it comes from the experiment configuration and must be trusted.
            self._augmentation_handle = Popen(augment_command + " --epoch 0", shell=True)
            self._augmentation_handle.wait()

    def _augment_data_next_epoch(self):
        """
        This is run in the background if reload_command is given to prepare data for the next epoch.

        If the previously started command has finished, the corpus is re-read and the
        command is restarted with the current epoch number. Otherwise the data from the
        last epoch is kept.
        """
        augment_command = self.reload_command
        if self._augmentation_handle is None:
            # first run
            self._augmentation_handle = Popen(augment_command + " --epoch %d" % self.training_state.epoch_num, shell=True)
            self._augmentation_handle.wait()
        # poll() does not block; it yields None while the command is still running.
        retcode = self._augmentation_handle.poll()
        if retcode is None:
            logger.info('new data set is not ready yet, using data from last epoch.')
            return
        if retcode != 0:
            # Previously a failing command was ignored silently; still reload (as before) but say so.
            logger.warning(f'reload_command exited with code {retcode}')
        if self.training_state.epoch_num > 0:
            logger.info('using reloaded data')
        # reload the data
        self._read_training_corpus()
        # restart data generation
        self._augmentation_handle = Popen(augment_command + " --epoch %d" % self.training_state.epoch_num, shell=True)

    def should_stop_training(self) -> bool:
        """
        Signal stopping if self.early_stopping_reached is marked or we exhausted the number of requested epochs.

        The requested epochs are exhausted once ``epoch_num`` exceeds ``run_for_epochs``,
        or equals it and every minibatch of the current epoch has been consumed.
        """
        if self.early_stopping_reached:
            return True
        if self.run_for_epochs is None:
            return False
        epoch_num = self.training_state.epoch_num
        return epoch_num > self.run_for_epochs or (
            epoch_num == self.run_for_epochs
            and self.training_state.steps_into_epoch >= self.cur_num_minibatches())

    def cur_num_minibatches(self) -> numbers.Integral:
        """
        Current number of minibatches (may change between epochs, e.g. for randomizing batchers or if reload_command is given)
        """
        return len(self.src_batches)

    def cur_num_sentences(self) -> numbers.Integral:
        """
        Current number of parallel sentences (may change between epochs, e.g. if reload_command is given)
        """
        return len(self.src_data)

    def _advance_epoch(self):
        """
        Shifts internal state to the next epoch, including data (re-)loading, batch re-packing and shuffling.
        """
        if self.reload_command is not None:
            if self.training_state.epoch_num == 0:
                self._augmentation_handle = None
                self._augment_data_initial()
            else:
                self._augment_data_next_epoch()
        if self.training_state.epoch_num == 0 or self.sample_train_sents or \
                self.model.src_reader.needs_reload() or self.model.trg_reader.needs_reload():
            event_trigger.set_train(True)
            self._read_training_corpus()
        # Draw a new seed for each epoch and record it in the training state. Both RNGs are
        # reseeded with it before batches are packed and the batch order is shuffled.
        self.training_state.epoch_seed = random.randint(1, 2147483647)
        random.seed(self.training_state.epoch_seed)
        np.random.seed(self.training_state.epoch_seed)
        self.src_batches, self.trg_batches = \
            self.batcher.pack(self.src_data, self.trg_data)
        self.training_state.epoch_num += 1
        self.training_state.steps_into_epoch = 0
        self.training_state.sents_into_epoch = 0
        self.minibatch_order = list(range(0, self.cur_num_minibatches()))
        np.random.shuffle(self.minibatch_order)
        event_trigger.new_epoch(training_task=self, num_sents=self.cur_num_sentences())

    def next_minibatch(self) -> Iterator:
        """
        Infinitely loops over training minibatches and advances internal epoch state after every complete sweep over the corpus.

        Returns:
          Generator yielding (src_batch,trg_batch) tuples
        """
        while True:
            self._advance_epoch()
            for batch_num in self.minibatch_order:
                src = self.src_batches[batch_num]
                trg = self.trg_batches[batch_num]
                self.training_state.steps_into_epoch += 1
                self.training_state.sents_into_epoch += src.batch_size()
                self.training_state.sents_since_start += src.batch_size()
                yield src, trg

    def training_step(self, src: batchers.Batch, trg: batchers.Batch):
        """
        Perform forward pass for the next training step and handle training logic (switching epoch, reshuffling, ..)

        Args:
          src: src minibatch
          trg: trg minibatch
        Returns:
          Loss, as computed by ``self.loss_calculator.calc_loss``
        """
        return self.loss_calculator.calc_loss(self.model, src, trg)

    def checkpoint_needed(self) -> bool:
        """
        Returns:
          whatever ``self.dev_loss_tracker.should_report_dev()`` says, i.e. whether a dev checkpoint is due
        """
        return self.dev_loss_tracker.should_report_dev()

    def checkpoint(self, control_learning_schedule: bool = True) -> bool:
        """
        Performs a dev checkpoint.

        Note that the default for ``control_learning_schedule`` is ``True`` here, unlike
        the base class.

        Args:
          control_learning_schedule: If False, only evaluate dev data.
                                     If True, also perform model saving, LR decay etc. if needed.
        Returns:
          True if the model needs saving, False otherwise. Always True when there are
          no dev tasks.
        Raises:
          RuntimeError: if the dev tasks produce no scores, or ``dev_combinator`` cannot be evaluated
        """
        if not self.dev_tasks:
            # case of no dev tasks: nothing to select on, so always save
            return True

        # Perform evaluation
        dev_scores = []
        with self.dev_loss_tracker.time_tracker:
            logger.info(f"> Checkpoint [{self.name}]" if self.name else "> Checkpoint")
            for dev_task in self.dev_tasks:
                dev_score = dev_task.eval()
                # an eval task may return a list of scores instead of a single score
                if isinstance(dev_score, list):
                    dev_scores.extend(dev_score)
                else:
                    dev_scores.append(dev_score)
            if not dev_scores:
                raise RuntimeError("dev tasks did not produce any scores")
            self.dev_loss_tracker.set_dev_score(dev_scores[0])
            for dev_score in dev_scores[1:]:
                self.dev_loss_tracker.add_aux_score(dev_score)
        self.dev_loss_tracker.report()

        if not control_learning_schedule:
            return False

        # Control the learning schedule: if this is the best, write the model out
        if self._update_best_dev_score(dev_scores):
            self.training_state.cur_attempt = 0
            logger.info(" best dev score, writing out model")
            return True
        # otherwise: learning rate decay / early stopping
        self._decay_lr_or_stop()
        return False

    def _update_best_dev_score(self, dev_scores) -> bool:
        """
        Record the current dev result in the training state if it is the best so far.

        The result is either ``dev_combinator`` evaluated on the scores' ``value()``s,
        or ``dev_scores[0]`` when no combinator is given.

        Returns:
          True iff the current dev result is a new best.
        """
        state = self.training_state
        if self.dev_combinator is not None:
            x = [y.value() for y in dev_scores]
            aevala = Interpreter(symtable={'x': x})
            my_score = aevala(self.dev_combinator)
            logger.info(' combined dev scores according to {}: {}'.format(self.dev_combinator, my_score))
            # asteval does not raise on a bad expression by default (it reports the error
            # and yields None). Fail loudly instead of comparing None against scores.
            if my_score is None:
                raise RuntimeError(f"could not evaluate dev_combinator {self.dev_combinator!r}")
            if state.best_dev_score is None or my_score > state.best_dev_score:
                state.best_dev_score = my_score
                return True
            return False
        if dev_scores[0].better_than(state.best_dev_score):
            state.best_dev_score = dev_scores[0]
            return True
        return False

    def _decay_lr_or_stop(self) -> None:
        """
        Count a non-improving checkpoint, then apply LR decay or early stopping.

        Acts once patience is exhausted. Patience is ``initial_patience`` before the
        first decay (if given) and ``patience`` otherwise.
        """
        state = self.training_state
        state.cur_attempt += 1
        if self.lr_decay >= 1.0:
            # LR decay (and with it early stopping) is disabled
            return
        if self.initial_patience is not None and state.num_times_lr_decayed == 0:
            patience = self.initial_patience
        else:
            patience = self.patience
        if state.cur_attempt < patience:
            return
        state.num_times_lr_decayed += 1
        if state.num_times_lr_decayed > self.lr_decay_times:
            logger.info(' Early stopping')
            self.early_stopping_reached = True
            return
        state.cur_attempt = 0
        # self.trainer is not set in this class; see the class docstring.
        self.trainer.learning_rate *= self.lr_decay
        logger.info(' new learning rate: %s' % self.trainer.learning_rate)
        if self.restart_trainer:
            logger.info(' restarting trainer and reverting learned weights to best checkpoint..')
            self.trainer.restart()
            param_collections.ParamManager.param_col.revert_to_best_model()
class TrainingState(object):
    """
    Mutable bookkeeping for the training loop.

    Every counter starts at zero, ``best_dev_score`` starts as ``None`` (no dev
    checkpoint evaluated yet), and ``epoch_seed`` is drawn from the ``random``
    module when the object is created.
    """

    def __init__(self) -> None:
        # learning-rate schedule bookkeeping
        self.num_times_lr_decayed = self.cur_attempt = 0
        # progress counters
        self.epoch_num = self.steps_into_epoch = self.sents_since_start = self.sents_into_epoch = 0
        self.best_dev_score = None
        # Seed used to pack and shuffle minibatches; keeping it around may help
        # resume crashed trainings in the future.
        self.epoch_seed = random.randint(1, 2**31 - 1)