tfm.nlp.ops.SequenceBeamSearch
Implementation of beam search loop.
tfm.nlp.ops.SequenceBeamSearch(
symbols_to_logits_fn,
vocab_size,
beam_size,
alpha,
max_decode_length,
eos_id,
padded_decode,
dtype=tf.float32,
noise_multiplier: float = 0.0,
decoding_name=None
)
Args |
symbols_to_logits_fn
|
A function that provides logits; it is the interface
to the Transformer model. It is called with these arguments:
ids -> A tensor with shape [batch_size * beam_size, index].
index -> A scalar.
cache -> A nested dictionary of tensors [batch_size * beam_size, ...].
The function must return a tuple of logits and the updated cache:
logits -> A tensor with shape [batch_size * beam_size, vocab_size].
updated cache -> A nested dictionary with the same structure as the input cache.
|
vocab_size
|
An integer, the size of the vocabulary, used for topk
computation.
|
beam_size
|
An integer, number of beams for beam search.
|
alpha
|
A float, defining the strength of length normalization.
|
max_decode_length
|
An integer, the maximum number of steps to decode a
sequence.
|
eos_id
|
An integer or a list of integers, the ID (or IDs) of the end-of-sentence token.
|
padded_decode
|
A bool, indicating if max_sequence_length padding is used
for beam search.
|
dtype
|
A tensorflow data type used for score computation. The default is
tf.float32.
|
noise_multiplier
|
A float, the amount of noise. The default is 0.0.
|
decoding_name
|
An optional name for the decoding loop tensors.
|
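The call contract described for symbols_to_logits_fn can be sketched without any framework. The example below is only an illustration: it uses nested Python lists in place of tensors, and the function name, the toy vocabulary, and the "steps_seen" cache key are all hypothetical. A real function would run a Transformer decoder step on tensors.

```python
VOCAB_SIZE = 4  # hypothetical toy vocabulary: 0=<pad>, 1=<eos>, 2="a", 3="b"


def toy_symbols_to_logits_fn(ids, index, cache):
    """Toy stand-in for one model decoding step (lists instead of tensors).

    ids:   one row per (batch * beam) hypothesis, holding the ids decoded so far.
    index: int, the current decoding step (unused by this toy model).
    cache: nested dict carried from step to step.
    Returns (logits, updated_cache); logits has one row of VOCAB_SIZE scores
    per hypothesis, and updated_cache keeps the structure of the input cache.
    """
    logits = []
    for row in ids:
        # Favour the id that follows the last decoded id, wrapping around.
        favoured = (row[-1] + 1) % VOCAB_SIZE
        logits.append([2.0 if tok == favoured else 0.0 for tok in range(VOCAB_SIZE)])
    updated_cache = {"steps_seen": cache["steps_seen"] + 1}
    return logits, updated_cache


# Two hypotheses (batch_size * beam_size = 2), each with two ids decoded so far.
logits, new_cache = toy_symbols_to_logits_fn([[0, 2], [0, 3]], 1, {"steps_seen": 1})
print(len(logits), len(logits[0]))  # rows = batch * beam, columns = vocab size
print(new_cache)
```

The point of the sketch is the shape relationship: one logits row per hypothesis, one column per vocabulary entry, and a returned cache whose structure matches the input cache.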
Methods
search
search(
initial_ids, initial_cache
)
Beam search for sequences with highest scores.
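The scores being ranked depend on alpha, which sets the strength of length normalization. This page does not state the formula. The sketch below assumes the GNMT-style penalty ((5 + length) / 6) ** alpha, a common choice in Transformer beam search implementations; the function names are hypothetical and this is not presented as the library's code.

```python
def length_penalty(length: int, alpha: float) -> float:
    """Assumed GNMT-style length penalty; the formula is not given on this page."""
    return ((5.0 + length) / 6.0) ** alpha


def normalized_score(log_prob: float, length: int, alpha: float) -> float:
    """Divide a summed log-probability by the length penalty."""
    return log_prob / length_penalty(length, alpha)


# alpha = 0 disables normalization: the penalty is 1 for every length.
print(length_penalty(10, 0.0))
# With alpha = 1, a sequence of length 7 has its log-probability halved.
print(normalized_score(-4.0, 7, 1.0))
```

Because log-probabilities are negative, dividing by a penalty that grows with length raises the score of longer sequences, so a larger alpha favours longer outputs under this assumed formula.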
Args |
initial_ids
|
Initial ids to pass into the symbols_to_logits_fn. An int tensor
with shape [batch_size, 1].
|
initial_cache
|
A dictionary storing values to be passed into the
symbols_to_logits_fn.
|
Returns |
The finished sequences (finished_seq) and their scores (finished_scores).
|
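To make the overall loop concrete, here is a minimal single-example beam search in plain Python. It is a sketch under stated assumptions, not the tfm implementation: it works on lists rather than tensors, handles one example rather than a batch, assumes the GNMT-style length penalty, keeps the top 2 * beam_size candidates per step as a design choice of this sketch, and omits early stopping. All names (toy_beam_search, toy_step_fn) are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def toy_beam_search(step_fn, initial_id, initial_cache, beam_size, alpha,
                    max_decode_length, eos_id):
    """Return up to beam_size (sequence, score) pairs, best first."""

    def penalty(length):
        # Assumed GNMT-style length penalty.
        return ((5.0 + length) / 6.0) ** alpha

    alive = [([initial_id], 0.0)]  # (sequence, summed log-probability)
    finished = []                  # (sequence, length-normalized score)
    # This sketch's cache holds no per-hypothesis state; a tensor-based
    # implementation would also reorder cache entries as beams are reshuffled.
    cache = initial_cache
    for index in range(max_decode_length):
        logits, cache = step_fn([seq for seq, _ in alive], index, cache)
        candidates = []
        for (seq, log_prob), row in zip(alive, logits):
            for token, token_log_prob in enumerate(log_softmax(row)):
                candidates.append((seq + [token], log_prob + token_log_prob))
        candidates.sort(key=lambda c: c[1], reverse=True)
        alive = []
        for seq, log_prob in candidates[:2 * beam_size]:
            if seq[-1] == eos_id:
                finished.append((seq, log_prob / penalty(index + 1)))
            elif len(alive) < beam_size:
                alive.append((seq, log_prob))
        if not alive:
            break
    if not finished:
        # Nothing reached EOS: fall back to the surviving hypotheses.
        finished = [(s, lp / penalty(len(s) - 1)) for s, lp in alive]
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:beam_size]


def toy_step_fn(ids, index, cache):
    """Toy model: emit token 2 until a sequence holds 3 ids, then prefer EOS (1)."""
    logits = [[0.0, 5.0, 0.0] if len(seq) >= 3 else [0.0, 0.0, 5.0] for seq in ids]
    return logits, cache


results = toy_beam_search(toy_step_fn, initial_id=0, initial_cache={},
                          beam_size=2, alpha=0.6, max_decode_length=5, eos_id=1)
best_seq, best_score = results[0]
print(best_seq)  # [0, 2, 2, 1]
```

The returned list plays the role of finished_seq and finished_scores for a single example: sequences that ended in EOS, ranked by length-normalized score.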
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-02-02 UTC.