import warnings
from typing import Callable, List, Optional, Sequence, Tuple, Union
import math
import random
from abc import ABC, abstractmethod
from functools import lru_cache
import numbers
import numpy as np
import dynet as dy
from xnmt.persistence import serializable_init, Serializable
from xnmt import expression_seqs
from xnmt.transducers import recurrent
from xnmt import sent
class Batch(ABC):
  """
  An abstract base class for minibatches of things.
  """
  @abstractmethod
  def batch_size(self) -> int:
    # Number of items in this batch; must be implemented by subclasses.
    raise NotImplementedError()
  def sent_len(self) -> int:
    # Common (padded) length of the items in this batch.
    # Deliberately not abstract: subclasses without a meaningful length may
    # leave it unimplemented, in which case calling it raises.
    raise NotImplementedError()
class ListBatch(list, Batch):
  """
  A minibatch of things, behaving like a Python list.

  The list semantics are augmented with batch semantics: the contents form a
  (mini)batch, and an optional mask can mark padded positions when the batch
  holds items of unequal size. Instances should be treated as immutable.

  Args:
    batch_elements: list of things
    mask: optional mask when batch contains items of unequal size
  """
  def __init__(self, batch_elements: list, mask: 'Mask' = None) -> None:
    assert len(batch_elements) > 0
    super().__init__(batch_elements)
    self.mask = mask

  def batch_size(self) -> int:
    # Bypass our own deprecated __len__ by going straight to list.__len__.
    return super().__len__()

  def sent_len(self) -> int:
    # All items share the (padded) length of the first one.
    return self[0].sent_len()

  def __len__(self):
    warnings.warn("use of ListBatch.__len__() is discouraged, use ListBatch.batch_size() "
                  "[or ListBatch.sent_len()] instead.", DeprecationWarning)
    return self.batch_size()

  def __getitem__(self, key):
    selected = super().__getitem__(key)
    # Slicing yields a smaller batch; single indexing yields the item itself.
    return ListBatch(selected) if isinstance(key, slice) else selected
class CompoundBatch(Batch):
  """
  A compound batch contains several parallel batches.

  All contained batches must have the same batch size; iteration and integer
  indexing produce :class:`sent.CompoundSentence` objects assembled from the
  corresponding element of each sub-batch.

  Args:
    *batch_elements: one or several batches
  """
  def __init__(self, *batch_elements: Batch) -> None:
    assert len(batch_elements) > 0
    self.batches = batch_elements

  def batch_size(self) -> numbers.Integral:
    # All sub-batches are parallel, so the first one's size is the batch size.
    return self.batches[0].batch_size()

  def sent_len(self) -> numbers.Integral:
    # Total length across all parallel batches.
    return sum(b.sent_len() for b in self.batches)

  def __iter__(self):
    for i in range(self.batch_size()):
      yield sent.CompoundSentence(sents=[b[i] for b in self.batches])

  def __getitem__(self, key):
    if isinstance(key, numbers.Integral):
      return sent.CompoundSentence(sents=[b[key] for b in self.batches])
    else:
      assert isinstance(key, slice)
      sel_batches = [b[key] for b in self.batches]
      # BUGFIX: the sub-batches must be passed as separate positional args;
      # passing the list itself would create a compound batch containing a
      # single plain list, which supports neither batch_size() nor sent_len().
      return CompoundBatch(*sel_batches)
class Mask(object):
  """
  An immutable mask specifies padded parts in a sequence or batch of sequences.
  Masks are represented as numpy array of dimensions batchsize x seq_len, with parts
  belonging to the sequence set to 0, and parts that should be masked set to 1
  Args:
    np_arr: numpy array of shape (batch_size, seq_len)
  """
  def __init__(self, np_arr: np.ndarray) -> None:
    self.np_arr = np_arr
    # Freeze the backing array so the mask is effectively immutable
    # (this also keeps the lru_cache on get_valid_position() coherent).
    self.np_arr.flags.writeable = False

  def __len__(self) -> int:
    # Sequence length (second dimension) -- NOT the batch size.
    return self.np_arr.shape[1]

  def batch_size(self) -> numbers.Integral:
    # Number of sequences in the batch (first dimension).
    return self.np_arr.shape[0]

  def reversed(self) -> 'Mask':
    # Mask matching time-reversed sequences: flip along the seq_len axis.
    return Mask(self.np_arr[:,::-1])

  def add_to_tensor_expr(self, tensor_expr: dy.Expression, multiplicator: Optional[numbers.Real]=None) -> dy.Expression:
    """
    Add this mask (optionally scaled by ``multiplicator``) to a batched tensor expression.

    Args:
      tensor_expr: batched dynet expression; assumed compatible with a
        (seq_len, 1) x batch_size mask tensor -- TODO confirm against callers
      multiplicator: optional scalar to multiply the mask by before adding
    Returns:
      ``tensor_expr`` unchanged if nothing is masked, otherwise the sum of
      ``tensor_expr`` and the (scaled) mask expression.
    """
    # TODO: might cache these expressions to save memory
    if np.count_nonzero(self.np_arr) == 0:
      # nothing masked: adding an all-zero tensor would be a no-op
      return tensor_expr
    else:
      if multiplicator is not None:
        mask_expr = dy.inputTensor(np.expand_dims(self.np_arr.transpose(), axis=1) * multiplicator, batched=True)
      else:
        mask_expr = dy.inputTensor(np.expand_dims(self.np_arr.transpose(), axis=1), batched=True)
      return tensor_expr + mask_expr

  def lin_subsampled(self, reduce_factor: Optional[numbers.Integral] = None, trg_len: Optional[numbers.Integral]=None) -> 'Mask':
    """
    Return a linearly subsampled mask.

    Exactly one of the arguments should be given: either shrink by an integer
    ``reduce_factor``, or resample to an explicit target length ``trg_len``.
    """
    if reduce_factor:
      return Mask(np.array([[self.np_arr[b,int(i*reduce_factor)] for i in range(int(math.ceil(len(self)/float(reduce_factor))))] for b in range(self.batch_size())]))
    else:
      return Mask(np.array([[self.np_arr[b,int(i*len(self)/float(trg_len))] for i in range(trg_len)] for b in range(self.batch_size())]))

  def cmult_by_timestep_expr(self, expr: dy.Expression, timestep: numbers.Integral, inverse: bool = False) -> dy.Expression:
    # TODO: might cache these expressions to save memory
    """
    Componentwise-multiply a one-timestep expression by this mask's column at ``timestep``.

    Args:
      expr: a dynet expression corresponding to one timestep
      timestep: index of current timestep
      inverse: True will keep the unmasked parts, False will zero out the unmasked parts
    """
    if inverse:
      # Column all zero: (1 - mask) is all ones, multiplication would be a no-op.
      if np.count_nonzero(self.np_arr[:,timestep:timestep+1]) == 0:
        return expr
      mask_exp = dy.inputTensor((1.0 - self.np_arr)[:,timestep:timestep+1].transpose(), batched=True)
    else:
      # Column all one: mask is all ones, multiplication would be a no-op.
      if np.count_nonzero(self.np_arr[:,timestep:timestep+1]) == self.np_arr[:,timestep:timestep+1].size:
        return expr
      mask_exp = dy.inputTensor(self.np_arr[:,timestep:timestep+1].transpose(), batched=True)
    return dy.cmult(expr, mask_exp)

  @lru_cache(maxsize=1)
  def get_valid_position(self, transpose: bool = True) -> List[numbers.Integral]:
    # NOTE(review): lru_cache on an instance method keys on self and keeps the
    # instance alive for the cache's lifetime; bounded here by maxsize=1, but
    # consider per-instance caching instead -- verify before changing.
    np_arr = self.np_arr
    if transpose: np_arr = np_arr.transpose()
    # Per row: indices of valid (non-masked, i.e. 0-valued) positions.
    x = [np.nonzero(1-arr)[0] for arr in np_arr]
    return x
class Batcher(object):
  """
  A template class to convert a list of sentences to several batches of sentences.
  Args:
    batch_size: batch size; number of sentences when granularity=='sent',
      a (src+trg) word budget per batch when granularity=='word'
    granularity: 'sent' or 'word'
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
    sort_within_by_trg_len: whether to sort by reverse trg len inside a batch
  """
  def __init__(self,
               batch_size: numbers.Integral,
               granularity: str = 'sent',
               pad_src_to_multiple: numbers.Integral = 1,
               sort_within_by_trg_len: bool = True) -> None:
    self.batch_size = batch_size
    self.granularity = granularity
    self.pad_src_to_multiple = pad_src_to_multiple
    self.sort_within_by_trg_len = sort_within_by_trg_len

  def is_random(self) -> bool:
    """
    Returns:
      True if there is some randomness in the batching process, False otherwise.
    """
    return False

  def create_single_batch(self,
                          src_sents: Sequence[sent.Sentence],
                          trg_sents: Optional[Sequence[sent.Sentence]] = None,
                          sort_by_trg_len: bool = False) -> Union[Batch, Tuple[Batch]]:
    """
    Create a single batch, either source-only or source-and-target.
    Args:
      src_sents: list of source-side inputs
      trg_sents: optional list of target-side inputs
      sort_by_trg_len: if True (and targets are specified), sort source- and target batches by target length
    Returns:
      a tuple of batches if targets were given, otherwise a single batch
    """
    if trg_sents is not None and sort_by_trg_len:
      # Sort jointly (longest trg first) so src and trg stay aligned.
      src_sents, trg_sents = zip(*sorted(zip(src_sents, trg_sents), key=lambda x: x[1].sent_len(), reverse=True))
    src_batch = pad(src_sents, pad_to_multiple=self.pad_src_to_multiple)
    if trg_sents is None:
      return src_batch
    else:
      trg_batch = pad(trg_sents)
      return src_batch, trg_batch

  def _add_single_batch(self, src_curr, trg_curr, src_ret, trg_ret, sort_by_trg_len=False):
    # Helper: batch up the current src/trg buffers and append to the result lists.
    if trg_curr:
      src_batch, trg_batch = self.create_single_batch(src_curr, trg_curr, sort_by_trg_len)
      trg_ret.append(trg_batch)
    else:
      src_batch = self.create_single_batch(src_curr, trg_curr, sort_by_trg_len)
    src_ret.append(src_batch)

  def _pack_by_order(self,
                     src: Sequence[sent.Sentence],
                     trg: Optional[Sequence[sent.Sentence]],
                     order: Sequence[numbers.Integral]) -> Tuple[Sequence[Batch], Sequence[Batch]]:
    """
    Pack batches by given order.
    Trg is optional for the case of self.granularity == 'sent'
    Args:
      src: src-side inputs
      trg: trg-side inputs
      order: order of inputs
    Returns:
      If trg is given: tuple of src / trg batches; Otherwise: only src batches
    """
    src_ret, src_curr = [], []
    trg_ret, trg_curr = [], []
    if self.granularity == 'sent':
      # Fixed number of sentences per batch: chop 'order' into batch_size chunks.
      for x in range(0, len(order), self.batch_size):
        src_selected = [src[y] for y in order[x:x + self.batch_size]]
        if trg:
          trg_selected = [trg[y] for y in order[x:x + self.batch_size]]
        else: trg_selected = None
        self._add_single_batch(src_selected,
                               trg_selected,
                               src_ret, trg_ret,
                               sort_by_trg_len=self.sort_within_by_trg_len)
    elif self.granularity == 'word':
      # Word budget: estimated padded cost of a batch is
      # (longest src + longest trg) * number of sentences in the batch.
      max_src, max_trg = 0, 0
      for i in order:
        max_src = max(_len_or_zero(src[i]), max_src)
        max_trg = max(_len_or_zero(trg[i]), max_trg)
        if (max_src + max_trg) * (len(src_curr) + 1) > self.batch_size and len(src_curr) > 0:
          # Adding sentence i would blow the budget: flush the current batch
          # and start a new one containing only sentence i.
          self._add_single_batch(src_curr, trg_curr, src_ret, trg_ret, sort_by_trg_len=self.sort_within_by_trg_len)
          max_src = _len_or_zero(src[i])
          max_trg = _len_or_zero(trg[i])
          src_curr = [src[i]]
          trg_curr = [trg[i]]
        else:
          src_curr.append(src[i])
          trg_curr.append(trg[i])
      # Flush the final (possibly partial) batch.
      self._add_single_batch(src_curr, trg_curr, src_ret, trg_ret, sort_by_trg_len=self.sort_within_by_trg_len)
    else:
      raise RuntimeError("Illegal granularity specification {}".format(self.granularity))
    if trg:
      return src_ret, trg_ret
    else:
      return src_ret

  def pack(self, src: Sequence[sent.Sentence], trg: Sequence[sent.Sentence]) \
          -> Tuple[Sequence[Batch], Sequence[Batch]]:
    """
    Create a list of src/trg batches based on provided src/trg inputs.
    Args:
      src: list of src-side inputs
      trg: list of trg-side inputs
    Returns:
      tuple of lists of src and trg batches
    """
    raise NotImplementedError("must be implemented by subclasses")
class InOrderBatcher(Batcher, Serializable):
  """
  A class to create batches in order of the original corpus, both across and within batches.
  Args:
    batch_size: batch size
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!InOrderBatcher"

  @serializable_init
  def __init__(self,
               batch_size: numbers.Integral = 1,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    super().__init__(batch_size,
                     pad_src_to_multiple=pad_src_to_multiple,
                     sort_within_by_trg_len=False)

  def pack(self, src: Sequence[sent.Sentence], trg: Optional[Sequence[sent.Sentence]]) \
          -> Tuple[Sequence[Batch], Sequence[Batch]]:
    """
    Pack batches, preserving corpus order. Unlike other batchers, trg sentences are optional.
    Args:
      src: list of src-side inputs
      trg: optional list of trg-side inputs
    Returns:
      src batches if trg was not given; tuple of src batches and trg batches if trg was given
    """
    identity_order = list(range(len(src)))
    return self._pack_by_order(src, trg, identity_order)
class ShuffleBatcher(Batcher):
  """
  A template class to create batches through randomly shuffling without sorting.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    batch_size: batch size
    granularity: 'sent' or 'word'
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  def __init__(self,
               batch_size: numbers.Integral,
               granularity: str = 'sent',
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    super().__init__(batch_size=batch_size,
                     granularity=granularity,
                     pad_src_to_multiple=pad_src_to_multiple,
                     sort_within_by_trg_len=True)

  def pack(self, src: Sequence[sent.Sentence], trg: Optional[Sequence[sent.Sentence]]) \
          -> Tuple[Sequence[Batch], Sequence[Batch]]:
    """Shuffle the corpus order randomly, then pack into batches."""
    shuffled_order = list(range(len(src)))
    np.random.shuffle(shuffled_order)
    return self._pack_by_order(src, trg, shuffled_order)

  def is_random(self) -> bool:
    """Always random: the corpus order is shuffled before packing."""
    return True
class SortBatcher(Batcher):
  """
  A template class to create batches through bucketing sentence length.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    batch_size: batch size
    granularity: 'sent' or 'word'
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  # small random jitter used to break ties between equal sort keys
  __tiebreaker_eps = 1.0e-7

  def __init__(self,
               batch_size: numbers.Integral,
               granularity: str = 'sent',
               sort_key: Callable = lambda x: x[0].sent_len(),
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    super().__init__(batch_size, granularity=granularity,
                     pad_src_to_multiple=pad_src_to_multiple,
                     sort_within_by_trg_len=True)
    self.sort_key = sort_key
    self.break_ties_randomly = break_ties_randomly

  def pack(self, src: Sequence[sent.Sentence], trg: Optional[Sequence[sent.Sentence]]) \
          -> Tuple[Sequence[Batch], Sequence[Batch]]:
    """Sort sentence pairs by sort_key (optionally jittered to break ties) and pack into batches."""
    eps = SortBatcher.__tiebreaker_eps
    if self.break_ties_randomly:
      keys = [self.sort_key(pair) + random.uniform(-eps, eps) for pair in zip(src, trg)]
    else:
      keys = [self.sort_key(pair) for pair in zip(src, trg)]
    return self._pack_by_order(src, trg, np.argsort(keys))

  def is_random(self) -> bool:
    """Randomness is present exactly when ties are broken randomly."""
    return self.break_ties_randomly
# Module level functions
def mark_as_batch(data: Sequence, mask: Optional[Mask] = None) -> Batch:
  """
  Mark a sequence of items as batch
  Args:
    data: sequence of things
    mask: optional mask
  Returns: a batch of things
  """
  # Already a batch and no new mask requested: return as-is.
  if mask is None and isinstance(data, Batch):
    return data
  return ListBatch(data, mask)
def is_batched(data: Sequence) -> bool:
  """
  Check whether some data is batched.
  Args:
    data: data to check
  Returns:
    True iff data is batched.
  """
  # Batched data is anything implementing the Batch interface.
  return isinstance(data, Batch)
def pad(batch: Sequence, pad_to_multiple: numbers.Integral = 1) -> Batch:
  """
  Apply padding to sentences in a batch.
  Args:
    batch: batch of sentences
    pad_to_multiple: pad sentences so their length is a multiple of this integer.
  Returns:
    batch containing padded items and a corresponding batch mask.
  """
  # Normalize to a list once: the original converted via list() only for the
  # type check but then indexed/iterated the raw argument, which broke for
  # non-indexable sequences.
  batch = list(batch)
  if isinstance(batch[0], sent.CompoundSentence):
    # Pad each parallel sub-sentence stream independently.
    ret = []
    for compound_i in range(len(batch[0].sents)):
      ret.append(
        pad(tuple(inp.sents[compound_i] for inp in batch), pad_to_multiple=pad_to_multiple))
    return CompoundBatch(*ret)
  max_len = max(_len_or_zero(item) for item in batch)
  if max_len % pad_to_multiple != 0:
    # round max_len up to the next multiple of pad_to_multiple
    max_len += pad_to_multiple - (max_len % pad_to_multiple)
  min_len = min(_len_or_zero(item) for item in batch)
  if min_len == max_len:
    # all items already have the target length: no padding, no mask needed
    return ListBatch(batch, mask=None)
  masks = np.zeros([len(batch), max_len])
  for i, v in enumerate(batch):
    # positions past the item's true length are padding -> masked with 1
    masks[i, _len_or_zero(v):] = 1.0
  padded_items = [item.create_padded_sent(max_len - item.sent_len()) for item in batch]
  return ListBatch(padded_items, mask=Mask(masks))
def _len_or_zero(val):
return val.sent_len() if hasattr(val, 'sent_len') else len(val) if hasattr(val, '__len__') else 0
class SrcBatcher(SortBatcher, Serializable):
  """
  A batcher that creates fixed-size batches, grouped by src len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    batch_size: batch size
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!SrcBatcher"

  @serializable_init
  def __init__(self,
               batch_size: numbers.Integral,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # sort by source length only
    super().__init__(batch_size,
                     sort_key=lambda pair: pair[0].sent_len(),
                     granularity='sent',
                     break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)
class TrgBatcher(SortBatcher, Serializable):
  """
  A batcher that creates fixed-size batches, grouped by trg len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    batch_size: batch size
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!TrgBatcher"

  @serializable_init
  def __init__(self,
               batch_size: numbers.Integral,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # sort by target length only
    super().__init__(batch_size,
                     sort_key=lambda pair: pair[1].sent_len(),
                     granularity='sent',
                     break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)
class SrcTrgBatcher(SortBatcher, Serializable):
  """
  A batcher that creates fixed-size batches, grouped by src len, then trg len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    batch_size: batch size
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!SrcTrgBatcher"

  @serializable_init
  def __init__(self,
               batch_size: numbers.Integral,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # Primary key: src length; secondary (tie-break) key: trg length.
    # Use the explicit sent_len() accessor (was len(x[1])) for consistency with
    # WordSrcTrgBatcher and to avoid relying on __len__ of sentence objects.
    super().__init__(batch_size,
                     sort_key=lambda x: x[0].sent_len() + 1.0e-6 * x[1].sent_len(),
                     granularity='sent', break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)
class TrgSrcBatcher(SortBatcher, Serializable):
  """
  A batcher that creates fixed-size batches, grouped by trg len, then src len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    batch_size: batch size
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!TrgSrcBatcher"

  @serializable_init
  def __init__(self,
               batch_size: numbers.Integral,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # Primary key: trg length; secondary (tie-break) key: src length.
    # Use the explicit sent_len() accessor (was len(x[0])) for consistency with
    # WordTrgSrcBatcher and to avoid relying on __len__ of sentence objects.
    super().__init__(batch_size,
                     sort_key=lambda x: x[1].sent_len() + 1.0e-6 * x[0].sent_len(),
                     granularity='sent',
                     break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)
class SentShuffleBatcher(ShuffleBatcher, Serializable):
  """
  A batcher that creates fixed-size batches of random order.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    batch_size: batch size
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!SentShuffleBatcher"

  @serializable_init
  def __init__(self,
               batch_size: numbers.Integral,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # sentence granularity: batch_size counts sentences
    super().__init__(batch_size,
                     granularity='sent',
                     pad_src_to_multiple=pad_src_to_multiple)
class WordShuffleBatcher(ShuffleBatcher, Serializable):
  """
  A batcher that creates fixed-size batches, grouped by src len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    words_per_batch: number of src+trg words in each batch
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!WordShuffleBatcher"

  @serializable_init
  def __init__(self,
               words_per_batch: numbers.Integral,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # word granularity: batch_size is a src+trg word budget
    super().__init__(words_per_batch,
                     granularity='word',
                     pad_src_to_multiple=pad_src_to_multiple)
class WordSortBatcher(SortBatcher):
  """
  Base class for word sort-based batchers.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    words_per_batch: number of src+trg words in each batch
    avg_batch_size: avg number of sentences in each batch (if words_per_batch not given)
    sort_key:
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  def __init__(self,
               words_per_batch: Optional[numbers.Integral],
               avg_batch_size: Optional[numbers.Real],
               sort_key: Callable,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # Sanity checks: exactly one of the two sizing modes must be chosen.
    if words_per_batch is None and avg_batch_size is None:
      raise ValueError("either words_per_batch or avg_batch_size must be specified.")
    if words_per_batch and avg_batch_size:
      raise ValueError("words_per_batch and avg_batch_size are mutually exclusive.")
    super().__init__(words_per_batch, sort_key=sort_key, granularity='word',
                     break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)
    # kept so subclasses can derive batch_size lazily from the corpus
    self.avg_batch_size = avg_batch_size
class WordSrcBatcher(WordSortBatcher, Serializable):
  """
  A batcher that creates variable-sized batches with given average (src+trg) words per batch, grouped by src len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    words_per_batch: number of src+trg words in each batch
    avg_batch_size: avg number of sentences in each batch (if words_per_batch not given)
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!WordSrcBatcher"

  @serializable_init
  def __init__(self,
               words_per_batch: Optional[numbers.Integral] = None,
               avg_batch_size: Optional[numbers.Real] = None,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    super().__init__(words_per_batch, avg_batch_size,
                     sort_key=lambda pair: pair[0].sent_len(),
                     break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)

  def _pack_by_order(self, src, trg, order):
    # Derive the word budget from the corpus' average sentence length when
    # avg_batch_size (in sentences) was given instead of words_per_batch.
    if self.avg_batch_size:
      total_words = sum([s.sent_len() for s in src]) + sum([s.sent_len() for s in trg])
      self.batch_size = total_words / len(src) * self.avg_batch_size
    return super()._pack_by_order(src, trg, order)
class WordTrgBatcher(WordSortBatcher, Serializable):
  """
  A batcher that creates variable-sized batches with given average (src+trg) words per batch, grouped by trg len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    words_per_batch: number of src+trg words in each batch
    avg_batch_size: avg number of sentences in each batch (if words_per_batch not given)
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!WordTrgBatcher"

  @serializable_init
  def __init__(self,
               words_per_batch: Optional[numbers.Integral] = None,
               avg_batch_size: Optional[numbers.Real] = None,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    super().__init__(words_per_batch, avg_batch_size,
                     sort_key=lambda pair: pair[1].sent_len(),
                     break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)

  def _pack_by_order(self, src, trg, order):
    # Derive the word budget from the corpus' average sentence length when
    # avg_batch_size (in sentences) was given instead of words_per_batch.
    if self.avg_batch_size:
      total_words = sum([s.sent_len() for s in src]) + sum([s.sent_len() for s in trg])
      self.batch_size = total_words / len(src) * self.avg_batch_size
    return super()._pack_by_order(src, trg, order)
class WordSrcTrgBatcher(WordSortBatcher, Serializable):
  """
  A batcher that creates variable-sized batches with given average number of src + trg words per batch, grouped by src len, then trg len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    words_per_batch: number of src+trg words in each batch
    avg_batch_size: avg number of sentences in each batch (if words_per_batch not given)
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!WordSrcTrgBatcher"

  @serializable_init
  def __init__(self,
               words_per_batch: Optional[numbers.Integral] = None,
               avg_batch_size: Optional[numbers.Real] = None,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # primary key: src length; tiny trg-length term breaks src-length ties
    super().__init__(words_per_batch, avg_batch_size,
                     sort_key=lambda pair: pair[0].sent_len() + 1.0e-6 * pair[1].sent_len(),
                     break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)

  def _pack_by_order(self, src, trg, order):
    # Derive the word budget from the corpus' average sentence length when
    # avg_batch_size (in sentences) was given instead of words_per_batch.
    if self.avg_batch_size:
      total_words = sum([s.sent_len() for s in src]) + sum([s.sent_len() for s in trg])
      self.batch_size = total_words / len(src) * self.avg_batch_size
    return super()._pack_by_order(src, trg, order)
class WordTrgSrcBatcher(WordSortBatcher, Serializable):
  """
  A batcher that creates variable-sized batches with given average number of src + trg words per batch, grouped by trg len, then src len.
  Sentences inside each batch are sorted by reverse trg length.
  Args:
    words_per_batch: number of src+trg words in each batch
    avg_batch_size: avg number of sentences in each batch (if words_per_batch not given)
    break_ties_randomly: if True, randomly shuffle sentences of the same src length before creating batches.
    pad_src_to_multiple: pad source sentences so its length is multiple of this integer.
  """
  yaml_tag = "!WordTrgSrcBatcher"

  @serializable_init
  def __init__(self,
               words_per_batch: Optional[numbers.Integral] = None,
               avg_batch_size: Optional[numbers.Real] = None,
               break_ties_randomly: bool = True,
               pad_src_to_multiple: numbers.Integral = 1) -> None:
    # primary key: trg length; tiny src-length term breaks trg-length ties
    super().__init__(words_per_batch, avg_batch_size,
                     sort_key=lambda pair: pair[1].sent_len() + 1.0e-6 * pair[0].sent_len(),
                     break_ties_randomly=break_ties_randomly,
                     pad_src_to_multiple=pad_src_to_multiple)

  def _pack_by_order(self, src, trg, order):
    # Derive the word budget from the corpus' average sentence length when
    # avg_batch_size (in sentences) was given instead of words_per_batch.
    if self.avg_batch_size:
      total_words = sum([s.sent_len() for s in src]) + sum([s.sent_len() for s in trg])
      self.batch_size = total_words / len(src) * self.avg_batch_size
    return super()._pack_by_order(src, trg, order)
def truncate_batches(*xl: Union[dy.Expression, Batch, Mask, recurrent.UniLSTMState]) \
        -> Sequence[Union[dy.Expression, Batch, Mask, recurrent.UniLSTMState]]:
  """
  Truncate a list of batched items so that all items have the batch size of the input with the smallest batch size.
  Inputs can be of various types and would usually correspond to a single time step.
  Assume that the batch elements with index 0 correspond across the inputs, so that batch elements will be truncated
  from the top, i.e. starting with the highest-indexed batch elements.
  Masks are not considered even if attached to a input of :class:`Batch` type.
  Args:
    *xl: batched timesteps of various types
  Returns:
    Copies of the inputs, truncated to consistent batch size.
  """
  batch_sizes = []
  for x in xl:
    if isinstance(x, dy.Expression) or isinstance(x, expression_seqs.ExpressionSequence):
      batch_sizes.append(x.dim()[1])
    elif isinstance(x, Batch):
      # Use the explicit accessor: len() on a ListBatch goes through the
      # deprecated __len__ and emits a DeprecationWarning on every call.
      batch_sizes.append(x.batch_size())
    elif isinstance(x, Mask):
      batch_sizes.append(x.batch_size())
    elif isinstance(x, recurrent.UniLSTMState):
      batch_sizes.append(x.output().dim()[1])
    else:
      raise ValueError(f"unsupported type {type(x)}")
    assert batch_sizes[-1] > 0
  # hoist the loop-invariant target size
  min_batch_size = min(batch_sizes)
  ret = []
  for i, x in enumerate(xl):
    if batch_sizes[i] > min_batch_size:
      if isinstance(x, dy.Expression) or isinstance(x, expression_seqs.ExpressionSequence):
        # slice only the batch dimension, keeping all tensor dimensions intact
        ret.append(x[tuple([slice(None)]*len(x.dim()[0]) + [slice(min_batch_size)])])
      elif isinstance(x, Batch):
        ret.append(mark_as_batch(x[:min_batch_size]))
      elif isinstance(x, Mask):
        ret.append(Mask(x.np_arr[:min_batch_size]))
      elif isinstance(x, recurrent.UniLSTMState):
        ret.append(x[:,:min_batch_size])
      else:
        raise ValueError(f"unsupported type {type(x)}")
    else:
      ret.append(x)
  return ret