Implementation of beam search loop.
tfm.nlp.ops.SequenceBeamSearch(
symbols_to_logits_fn,
vocab_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode,
dtype=tf.float32,
noise_multiplier: float = 0.0,
decoding_name=None
)
Args

symbols_to_logits_fn: A function that provides the logits; this is the
interface to the Transformer model (see the sketch after this list). The
passed-in arguments are:
  ids -> A tensor with shape [batch_size * beam_size, index].
  index -> A scalar.
  cache -> A nested dictionary of tensors with shape [batch_size * beam_size, ...].
The function must return a tuple of the logits and the updated cache:
  logits -> A tensor with shape [batch_size * beam_size, vocab_size].
  updated cache -> A nested dictionary with the same structure as the input cache.
vocab_size: An integer, the size of the vocabulary, used for top-k computation.
beam_size: An integer, the number of beams for beam search.
alpha: A float defining the strength of length normalization.
max_decode_length: An integer, the maximum number of steps to decode a sequence.
eos_id: An integer or a list of integers; the ID(s) of the end-of-sentence token.
padded_decode: A bool indicating whether max_sequence_length padding is used for beam search.
dtype: A TensorFlow data type used for score computation. The default is tf.float32.
noise_multiplier: A float, the amount of noise; the default of 0.0 adds no noise.
decoding_name: An optional name for the decoding loop tensors.
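To make the symbols_to_logits_fn contract above concrete, here is a minimal
sketch of such a function. The vocabulary size, hidden size, and the single
embedding-plus-projection used as a stand-in for a Transformer decoder are
illustrative assumptions, not part of the API; a real decoder would also read
and update attention key/value caches stored in cache.

import tensorflow as tf

# Illustrative sizes for the toy model below (not prescribed by the API).
VOCAB_SIZE = 32
HIDDEN_SIZE = 16

# A toy stand-in for a Transformer decoder: embed the most recent token and
# project it to vocabulary logits.
embedding_table = tf.random.normal([VOCAB_SIZE, HIDDEN_SIZE])
output_projection = tf.random.normal([HIDDEN_SIZE, VOCAB_SIZE])


def symbols_to_logits_fn(ids, index, cache):
  """Maps partially decoded ids to next-token logits.

  ids: int tensor of decoded tokens, leading dimension batch_size * beam_size.
  index: scalar decoding step (unused by this toy model).
  cache: nested dict of tensors with leading dimension batch_size * beam_size;
    a real decoder would read and update its key/value caches here.
  """
  del index  # The toy model conditions only on the most recent token.
  last_token = ids[:, -1]                          # [batch * beam]
  hidden = tf.gather(embedding_table, last_token)  # [batch * beam, hidden]
  logits = tf.matmul(hidden, output_projection)    # [batch * beam, vocab]
  # Return the logits together with the (unchanged) cache, matching the
  # required (logits, updated cache) structure.
  return logits, cache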
Methods
search
search(
initial_ids, initial_cache
)
Beam search for sequences with highest scores.
Args

initial_ids: Initial ids to pass into the symbols_to_logits_fn. An int tensor
with shape [batch_size, 1].
initial_cache: A dictionary storing values to be passed into the
symbols_to_logits_fn.

Returns

finished_seq and finished_scores.
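As a rough end-to-end illustration, the sketch below wires the toy
symbols_to_logits_fn from the earlier sketch into SequenceBeamSearch and calls
search. The beam size, alpha, EOS id, and decode length are arbitrary
illustrative values, not recommended settings. Note that while the docstring
above gives initial_ids the shape [batch_size, 1], Model Garden call sites
commonly pass a rank-1 [batch_size] tensor of start ids; this sketch uses the
rank-1 form, so adjust it to whatever your installed version expects.

import tensorflow as tf
import tensorflow_models as tfm

BATCH_SIZE = 2

searcher = tfm.nlp.ops.SequenceBeamSearch(
    symbols_to_logits_fn=symbols_to_logits_fn,  # toy function sketched earlier
    vocab_size=VOCAB_SIZE,
    beam_size=4,
    alpha=0.6,             # moderate length normalization
    max_decode_length=10,
    eos_id=1,              # illustrative end-of-sentence id
    padded_decode=False,
    dtype=tf.float32,
)

# One start id per batch element (see the shape note above). The toy
# symbols_to_logits_fn ignores the cache, so an empty dict suffices here; a
# real Transformer would pass its encoder outputs and key/value caches.
initial_ids = tf.zeros([BATCH_SIZE], dtype=tf.int32)
initial_cache = {}

finished_seq, finished_scores = searcher.search(initial_ids, initial_cache)
# finished_seq:    decoded token ids, one sequence per beam per batch element.
# finished_scores: the corresponding per-beam scores (higher is better).
print(finished_seq.shape, finished_scores.shape)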