from collections import namedtuple
import math
from typing import Callable, List, Optional, Sequence
import numbers
import dynet as dy
import numpy as np
from xnmt import batchers, logger
from xnmt.modelparts import decoders
from xnmt.length_norm import NoNormalization, LengthNormalization
from xnmt.persistence import Serializable, serializable_init, bare
from xnmt.vocabs import Vocab
# Record type built by the search strategies in this file to carry their results.
SearchOutput = namedtuple('SearchOutput', ['word_ids', 'attentions', 'score', 'state', 'mask'])
"""
Output of the search
word_ids: list of generated word ids
attentions: list of attention vectors, one for each entry of word_ids
score: a single value of log(p(E|F))
logsoftmaxes: (not a field of this tuple) the log-softmax vector the score is read from: score = logsoftmax[word_id]
state: a NON-BACKPROPAGATEABLE state that is used to produce the logsoftmax layer;
       it is usually used to generate the 'baseline' in the reinforce loss
mask: whether a particular word id should be kept or ignored (1 = keep, 0 = ignore)
"""
class SearchStrategy(object):
    """
    Base template for strategies that turn the translator's output probabilities
    into a translation. (Non-batched operation)

    The only method, :meth:`generate_output`, raises unless a subclass overrides it.
    """

    def generate_output(self,
                        translator: 'xnmt.models.translators.AutoRegressiveTranslator',
                        initial_state: decoders.AutoRegressiveDecoderState,
                        src_length: Optional[numbers.Integral] = None) -> List[SearchOutput]:
        """
        Run the search. This base implementation always raises.

        Args:
          translator: a translator
          initial_state: initial decoder state
          src_length: length of src sequence, required for some types of length normalization
        Returns:
          List of SearchOutput tuples
        Raises:
          NotImplementedError: always, in this base class
        """
        message = 'generate_output must be implemented in SearchStrategy subclasses'
        raise NotImplementedError(message)
class GreedySearch(Serializable, SearchStrategy):
    """
    Greedy decoding: at every step only the single best-scoring word is kept
    (aka beam search with beam size 1).

    Args:
      max_len: maximum number of tokens to generate.
    """

    yaml_tag = '!GreedySearch'

    @serializable_init
    def __init__(self, max_len: numbers.Integral = 100) -> None:
        self.max_len = max_len

    def generate_output(self,
                        translator: 'xnmt.models.translators.AutoRegressiveTranslator',
                        initial_state: decoders.AutoRegressiveDecoderState,
                        src_length: Optional[numbers.Integral] = None) -> List[SearchOutput]:
        """
        Decode step by step, always taking the top-1 word, until every sequence has
        produced ``Vocab.ES`` or ``max_len`` steps have been taken.

        Args:
          translator: a translator
          initial_state: initial decoder state
          src_length: length of src sequence (not read by this implementation)
        Returns:
          A one-element list holding the SearchOutput of the greedy path.
        """
        # Per-step accumulators
        chosen_words = []
        chosen_scores = []
        step_attentions = []
        step_states = []
        step_masks = []

        # Flags from the previous step telling which sequences already emitted Vocab.ES;
        # None until the first step has been taken.
        finished = None
        dec_state = initial_state

        for _ in range(self.max_len):
            # Exactly one entry is appended per step, so the last entry is the previous word.
            last_word = chosen_words[-1] if chosen_words else None
            step_output = translator.add_input(last_word, dec_state)
            top_ids, top_scores = translator.best_k(step_output.state, 1, normalize_scores=True)
            word_id = top_ids[0]
            word_score = top_scores[0]
            dec_state = step_output.state

            # Wrap 0-d results so that the code below can always iterate over them.
            if len(word_id.shape) == 0:
                word_id = np.array([word_id])
                word_score = np.array([word_score])

            if finished is not None:
                # Sequences that already ended keep emitting Vocab.ES with a zeroed score.
                word_id = [Vocab.ES if is_done else word_id[i] for i, is_done in enumerate(finished)]
                step_mask = [0 if is_done else 1 for is_done in finished]
                word_score = [value * keep for (value, keep) in zip(word_score, step_mask)]
                step_masks.append(step_mask)

            # Packing outputs
            chosen_scores.append(word_score)
            chosen_words.append(word_id)
            step_attentions.append(step_output.attention)
            step_states.append(dec_state)

            # Stop as soon as every sequence has produced Vocab.ES.
            finished = [token == Vocab.ES for token in word_id]
            if all(finished):
                break

        # The first step is never masked; its mask is added here, once the batch size is known.
        step_masks.insert(0, [1] * len(finished))
        words = np.stack(chosen_words, axis=1)
        total_score = np.sum(chosen_scores, axis=0)
        return [SearchOutput(words, step_attentions, total_score, step_states, step_masks)]
class BeamSearch(Serializable, SearchStrategy):
    """
    Performs beam search.

    Args:
      beam_size: number of beams
      max_len: maximum number of tokens to generate.
      len_norm: type of length normalization to apply
      one_best: Whether to output the best hyp only or all completed hyps.
      scores_proc: apply an optional operation on all scores prior to choosing the top k.
                   E.g. use with :class:`xnmt.length_normalization.EosBooster`.
                   NOTE(review): the value is stored on the instance, but
                   ``generate_output`` in this class does not read it.
    """

    yaml_tag = '!BeamSearch'
    # One node of the search tree: accumulated score, translator output of the step that
    # produced ``word``, the parent node, and the word itself.
    Hypothesis = namedtuple('Hypothesis', ['score', 'output', 'parent', 'word'])

    @serializable_init
    def __init__(self,
                 beam_size: numbers.Integral = 1,
                 max_len: numbers.Integral = 100,
                 len_norm: LengthNormalization = bare(NoNormalization),
                 one_best: bool = True,
                 scores_proc: Optional[Callable[[np.ndarray], None]] = None) -> None:
        self.beam_size = beam_size
        self.max_len = max_len
        self.len_norm = len_norm
        self.one_best = one_best
        self.scores_proc = scores_proc

    def generate_output(self,
                        translator: 'xnmt.models.translators.AutoRegressiveTranslator',
                        initial_state: decoders.AutoRegressiveDecoderState,
                        src_length: Optional[numbers.Integral] = None) -> List[SearchOutput]:
        """
        Expand the beam for at most ``max_len`` steps, then backtrace the selected
        hypotheses into SearchOutput tuples.

        Args:
          translator: a translator
          initial_state: initial decoder state
          src_length: length of src sequence, passed on to the length normalization
        Returns:
          One SearchOutput per selected hypothesis, best first; only the best one
          when ``one_best`` is set.
        """
        beam = [self.Hypothesis(0, None, None, None)]
        finished = []

        for step in range(self.max_len):
            if len(finished) >= self.beam_size:
                break

            # Expand every hypothesis currently on the beam
            candidates = []
            for hyp in beam:
                # The root hypothesis carries no word and no translator output yet.
                if step > 0:
                    prev_word, prev_state = hyp.word, hyp.output.state
                else:
                    prev_word, prev_state = None, initial_state

                # A hypothesis ending with </s> is complete and is not expanded further
                if prev_word == Vocab.ES:
                    finished.append(hyp)
                    continue

                # Find the k best words at the next time step
                step_output = translator.add_input(prev_word, prev_state)
                best_words, best_scores = translator.best_k(step_output.state, self.beam_size,
                                                            normalize_scores=True)

                # Queue one child hypothesis per proposed word
                for next_word, word_score in zip(best_words, best_scores):
                    assert len(word_score.shape) == 0
                    child_score = self.len_norm.normalize_partial_topk(hyp.score, word_score, step + 1)
                    candidates.append(self.Hypothesis(child_score, step_output, hyp, next_word))

            # Keep only the best ``beam_size`` candidates for the next step
            candidates.sort(key=lambda cand: cand.score, reverse=True)
            beam = candidates[:self.beam_size]

        # No hypothesis reached </s>: fall back to whatever is left on the beam
        if not finished:
            finished = beam

        # Length normalization, then rank by normalized score (best first)
        normalized_scores = self.len_norm.normalize_completed(finished, src_length)
        ranked = sorted(zip(finished, normalized_scores), key=lambda pair: pair[1], reverse=True)

        # Take only the one best, if that's what was desired
        if self.one_best:
            ranked = [ranked[0]]

        # Backtracing + packing outputs
        results = []
        for end_hyp, final_score in ranked:
            word_ids = []
            attentions = []
            states = []
            node = end_hyp
            # Walk from the last word up to (but excluding) the root hypothesis.
            while node.parent is not None:
                word_ids.append(node.word)
                attentions.append(node.output.attention)
                states.append(node.output.state)
                node = node.parent
            # The walk collected everything last-to-first; restore generation order.
            word_ids.reverse()
            attentions.reverse()
            states.reverse()
            results.append(SearchOutput([word_ids], [attentions], [final_score], states,
                                        [1] * len(word_ids)))
        return results
class SamplingSearch(Serializable, SearchStrategy):
    """
    Performs search by sampling from the softmax probability distribution.
    Structured like greedy search, except that each word is drawn by
    ``translator.sample`` instead of being the top-1 word.

    Args:
      max_len: maximum number of tokens to generate per sample.
      sample_size: number of samples drawn by ``generate_output``.
    """

    yaml_tag = '!SamplingSearch'

    @serializable_init
    def __init__(self, max_len: numbers.Integral = 100, sample_size: numbers.Integral = 5) -> None:
        self.max_len = max_len
        self.sample_size = sample_size

    def generate_output(self,
                        translator: 'xnmt.models.translators.AutoRegressiveTranslator',
                        initial_state: decoders.AutoRegressiveDecoderState,
                        src_length: Optional[numbers.Integral] = None) -> List[SearchOutput]:
        """
        Draw ``sample_size`` samples, each starting from ``initial_state``.

        Args:
          translator: a translator
          initial_state: initial decoder state
          src_length: length of src sequence (not read by this implementation)
        Returns:
          List with one SearchOutput per sample.
        """
        return [self.sample_one(translator, initial_state) for _ in range(self.sample_size)]

    def sample_one(self,
                   translator: 'xnmt.models.translators.AutoRegressiveTranslator',
                   initial_state: decoders.AutoRegressiveDecoderState) -> SearchOutput:
        """
        Sample a single output sequence of at most ``max_len`` steps.

        Args:
          translator: a translator
          initial_state: initial decoder state
        Returns:
          SearchOutput holding the sampled word ids, attentions, a one-element list with
          the summed score, the decoder states and the per-step masks.
        """
        # Search variables
        prev_words = None
        dec_state = initial_state
        finished = None  # per-sequence end flags from the previous step

        # Outputs
        step_scores = []
        step_samples = []
        step_states = []
        step_attentions = []
        step_masks = []

        # Sample up to the max length
        for _ in range(self.max_len):
            step_output = translator.add_input(prev_words, dec_state)
            word_id, word_score = translator.sample(step_output.state, 1)[0]
            word_score = word_score.npvalue()
            assert word_score.shape == (1,)
            word_score = word_score[0]

            # Wrap 0-d results so that the code below can always iterate over them.
            if len(word_id.shape) == 0:
                word_id = np.array([word_id])
                word_score = np.array([word_score])

            if finished is not None:
                # Sequences that already ended keep emitting Vocab.ES with a zeroed score.
                word_id = [Vocab.ES if is_done else word_id[i] for i, is_done in enumerate(finished)]
                # masking for logsoftmax
                step_mask = [0 if is_done else 1 for is_done in finished]
                word_score = [value * keep for (value, keep) in zip(word_score, step_mask)]
                step_masks.append(step_mask)

            # Appending output
            step_scores.append(word_score)
            step_samples.append(word_id)
            step_states.append(step_output.state)
            step_attentions.append(step_output.attention)

            # Next time step
            prev_words = word_id
            dec_state = step_output.state

            # Stop as soon as every sequence has produced Vocab.ES.
            finished = [token == Vocab.ES for token in word_id]
            if all(finished):
                break

        # Packing output
        total_score = [np.sum(step_scores)]
        # The first step is never masked; its mask is added here, once the batch size is known.
        step_masks.insert(0, [1] * len(finished))
        stacked_samples = np.stack(step_samples, axis=1)
        return SearchOutput(stacked_samples, step_attentions, total_score, step_states, step_masks)
class MctsNode(object):
    """
    One node of the Monte Carlo search tree.

    A node stores the word that led to it, the decoder state reached after that word,
    the log-probabilities of the possible next words (``prior_dist``), its visit
    counter (``tries``) and the running average of the values backed up through it
    (``avg_value``). Children are keyed by word id.

    Args:
      parent: parent node, or None for the root
      prior_dist: log of softmax over the next word, indexable by word id
      word: word id that led to this node, or None for the root
      attention: attention associated with ``word``
      translator: translator used to advance the decoder
      dec_state: decoder state of this node
    """

    def __init__(self,
                 parent: Optional['MctsNode'],
                 prior_dist: np.ndarray,
                 word: Optional[numbers.Integral],
                 attention: Optional[List[np.ndarray]],
                 translator: 'xnmt.models.translators.AutoRegressiveTranslator',
                 dec_state: 'decoders.AutoRegressiveDecoderState') -> None:
        self.parent = parent
        self.prior_dist = prior_dist  # log of softmax
        self.word = word
        self.attention = attention
        self.translator = translator
        self.dec_state = dec_state
        self.tries = 0
        self.avg_value = 0.0
        self.children = {}
        # If the child is unvisited, set its avg_value to
        # parent value - reduction where reduction = c * sqrt(sum of scores of all visited children)
        # where c is 0.25 in leela
        self.reduction = 0.0

    def choose_child(self) -> numbers.Integral:
        """Return the word id with the highest priority among all entries of ``prior_dist``."""
        return max(range(len(self.prior_dist)), key=self.compute_priority)

    def compute_priority(self, move: numbers.Integral) -> numbers.Real:
        """
        Score ``move`` as its value estimate plus an exploration bonus.

        The value estimate is the log prior of the move plus the child's average value;
        for an unvisited move this node's own average value minus ``reduction`` is used
        instead. The bonus grows with the move's prior probability and with this node's
        visit count, and shrinks with the child's visit count.

        Args:
          move: word id to score
        Returns:
          The priority of ``move``.
        """
        if move not in self.children:
            child_val = self.prior_dist[move] + self.avg_value - self.reduction
            child_tries = 0
        else:
            child = self.children[move]
            child_val = self.prior_dist[move] + child.avg_value
            child_tries = child.tries

        K = 5.0  # weight of the exploration bonus
        exp_term = math.sqrt(1.0 * self.tries + 1.0) / (child_tries + 1)
        # TODO: This exp could be done before the prior is passed into the MctsNode
        # so it's done as a big batch
        exp_term *= K * math.exp(self.prior_dist[move])
        return child_val + exp_term

    def expand(self) -> 'MctsNode':
        """
        Follow the highest-priority moves down the tree until an unexplored move is
        found, create a child node for it and return that child. A node whose word is
        ``Vocab.ES`` is terminal and returns itself.
        """
        if self.word == Vocab.ES:
            return self

        move = self.choose_child()
        if move in self.children:
            return self.children[move].expand()

        output = self.translator.add_input(move, self.dec_state)
        prior_dist = self.translator.calc_log_probs(output.state).npvalue()

        # Words from the root down to this node, used only for the message printed below.
        path = []
        node = self
        while node is not None:
            path.append(node.word)
            node = node.parent
        path = ' '.join(str(word) for word in reversed(path))
        print('Creating new node:', path, '+', move)

        new_node = MctsNode(self, prior_dist, move, output.attention,
                            self.translator, output.state)
        self.children[move] = new_node
        return new_node

    def rollout(self,
                sample_func: Callable[[np.ndarray], numbers.Integral],
                max_len: numbers.Integral):
        """
        Continue decoding from this node, choosing every next word with ``sample_func``,
        until ``Vocab.ES`` is chosen or ``max_len`` words have been generated.

        Args:
          sample_func: maps a vector of log-probabilities to the chosen word id
          max_len: maximum number of words to generate
        Returns:
          Tuple ``(prefix, scores)`` of the generated word ids and their log-probabilities;
          both are empty if this node's word already is ``Vocab.ES``.
        """
        prefix = []
        scores = []
        if self.word == Vocab.ES:
            return prefix, scores

        prev_word = None
        dec_state = self.dec_state
        while True:
            output = self.translator.add_input(prev_word, dec_state)
            logsoftmax = self.translator.calc_log_probs(output.state).npvalue()
            best_id = sample_func(logsoftmax)
            print("Rolling out node with word=", best_id, 'score=', logsoftmax[best_id])
            prefix.append(best_id)
            scores.append(logsoftmax[best_id])
            if best_id == Vocab.ES or len(prefix) >= max_len:
                break
            prev_word = best_id
            dec_state = output.state
        return prefix, scores

    def backup(self, result) -> None:
        """
        Fold ``result`` into this node's running average, count the visit, and propagate
        to the parent with this node's own log prior added to the result.

        Args:
          result: value to back up through this node
        """
        print('Backing up', result)
        self.avg_value = self.avg_value * (self.tries / (self.tries + 1)) + result / (self.tries + 1)
        self.tries += 1
        if self.parent is not None:
            my_prob = self.parent.prior_dist[self.word]
            self.parent.backup(result + my_prob)

    def collect(self, words, attentions) -> None:
        """
        Append the words and attentions along the most-visited path below (and including)
        this node to the given lists. The root, whose word is None, contributes nothing.

        Args:
          words: list that receives the word ids, in path order
          attentions: list that receives the matching attentions
        """
        if self.word is not None:
            words.append(self.word)
            attentions.append(self.attention)
        if self.children:
            # ``tries`` is the visit counter maintained by backup().
            best_child = max(self.children.values(), key=lambda child: child.tries)
            best_child.collect(words, attentions)
def random_choice(logsoftmax: np.ndarray) -> numbers.Integral:
    """
    Sample one index with probability proportional to ``exp(logsoftmax)``.

    Args:
      logsoftmax: vector of log-probabilities
    Returns:
      The sampled index.
    """
    unnormalised = np.exp(logsoftmax)
    # Renormalise so that the probabilities handed to the sampler sum to one.
    probabilities = unnormalised / sum(unnormalised)
    return np.random.choice(len(probabilities), 1, p=probabilities)[0]
def greedy_choice(logsoftmax: np.ndarray) -> numbers.Integral:
    """
    Pick the index of the largest entry.

    Args:
      logsoftmax: vector of log-probabilities
    Returns:
      Index of the maximum of ``logsoftmax``.
    """
    best_index = np.argmax(logsoftmax)
    return best_index
class MctsSearch(Serializable, SearchStrategy):
    """
    Performs search with Monte Carlo Tree Search

    Args:
      visits: number of expand / rollout / backup rounds run from the root.
      max_len: maximum number of words generated by a single rollout.
    """

    yaml_tag = '!MctsSearch'

    @serializable_init
    def __init__(self, visits: numbers.Integral = 200, max_len: numbers.Integral = 100) -> None:
        self.max_len = max_len
        self.visits = visits

    def generate_output(self,
                        translator: 'xnmt.models.translators.AutoRegressiveTranslator',
                        dec_state: decoders.AutoRegressiveDecoderState,
                        src_length: Optional[numbers.Integral] = None) -> List[SearchOutput]:
        """
        Build a search tree with ``visits`` rounds of expansion, random rollout and
        backup, then read the output off the tree by following the most-tried child
        from the root until a node without children is reached.

        Args:
          translator: a translator
          dec_state: initial decoder state
          src_length: length of src sequence (not read by this implementation)
        Returns:
          A one-element list holding the SearchOutput of the most-tried path.
        """
        orig_dec_state = dec_state
        output = translator.add_input(None, dec_state)
        dec_state = output.state
        # Feeding no word is expected to leave the decoder state untouched.
        assert dec_state == orig_dec_state
        # Use the translator passed in; this class never sets a ``translator`` attribute.
        logsoftmax = translator.calc_log_probs(dec_state).npvalue()
        root_node = MctsNode(None, logsoftmax, None, None, translator, dec_state)

        for _ in range(self.visits):
            terminal = root_node.expand()
            _, rollout_scores = terminal.rollout(random_choice, self.max_len)
            terminal.backup(sum(rollout_scores))

        print()
        print('Final stats:')
        for word in root_node.children:
            print (word, root_node.compute_priority(word), root_node.prior_dist[word] + root_node.children[word].avg_value, root_node.children[word].tries)
        print()

        scores = []
        word_ids = []
        attentions = []
        states = []
        masks = []

        # Follow the most-tried child at every level.
        node = root_node
        while node.children:
            best_word = max(node.children, key=lambda word: node.children[word].tries)
            best_child = node.children[best_word]
            scores.append(node.prior_dist[best_word])
            word_ids.append(best_word)
            attentions.append(best_child.attention)
            states.append(node.dec_state)
            masks.append(1)
            node = best_child

        word_ids = np.expand_dims(word_ids, axis=0)
        # SearchOutput has exactly five fields: word_ids, attentions, score, state, mask.
        return [SearchOutput(word_ids, attentions, scores, states, masks)]