Warning: This project is deprecated. TensorFlow Addons has stopped development, and the project will only provide minimal maintenance releases until May 2024. See the full announcement on GitHub.
tfa.seq2seq.GreedyEmbeddingSampler
An inference sampler that takes the argmax of the output distribution.
Inherits From: Sampler
tfa.seq2seq.GreedyEmbeddingSampler(
embedding_fn: Optional[Callable] = None
)
Uses the argmax of the output (treated as logits) and passes the
result through an embedding layer to get the next input.
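One greedy step can be sketched in NumPy as follows. This is a minimal illustration assuming logits of shape [batch_size, vocab_size] and an embedding matrix of shape [vocab_size, embed_dim]. greedy_step is a hypothetical helper, not part of TFA, and np.take along axis 0 stands in for tf.nn.embedding_lookup.

```python
import numpy as np


def greedy_step(logits: np.ndarray, embedding: np.ndarray):
    """Hypothetical helper: argmax over logits, then embed the chosen ids.

    logits:    [batch_size, vocab_size] decoder outputs treated as logits.
    embedding: [vocab_size, embed_dim] embedding matrix.
    Returns (sample_ids, next_inputs).
    """
    # Pick the highest-scoring token id for each batch element.
    sample_ids = np.argmax(logits, axis=-1).astype(np.int32)
    # Gather one embedding row per chosen id (NumPy stand-in for embedding_lookup).
    next_inputs = np.take(embedding, sample_ids, axis=0)
    return sample_ids, next_inputs


# Batch of 2, vocabulary of 3 tokens, 2-dimensional embeddings.
logits = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.0, 0.5]])
embedding = np.array([[0.0, 0.0],
                      [1.0, 1.0],
                      [2.0, 2.0]])
ids, inputs = greedy_step(logits, embedding)
print(ids)  # [1 0]
```

Here inputs holds rows 1 and 0 of the embedding matrix, which become the decoder's next inputs.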
Args
embedding_fn | An optional callable that takes a vector tensor of ids (the argmax ids). The returned tensor is passed to the decoder as input. Defaults to tf.nn.embedding_lookup.
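The difference between the default lookup and a custom embedding_fn can be sketched in NumPy. default_embedding_fn is a stand-in for tf.nn.embedding_lookup, and one_hot_embedding_fn is a hypothetical custom callable; neither is part of TFA. With the real class, a custom callable would be passed at construction, for example GreedyEmbeddingSampler(embedding_fn=lambda ids: tf.one_hot(ids, vocab_size)).

```python
import numpy as np


def default_embedding_fn(embedding: np.ndarray):
    """Stand-in for the documented default: map each id to a row of `embedding`."""
    return lambda ids: np.take(embedding, ids, axis=0)


def one_hot_embedding_fn(vocab_size: int):
    """Hypothetical custom embedding_fn: feed one-hot vectors instead of learned embeddings."""
    return lambda ids: np.eye(vocab_size, dtype=np.float32)[ids]


embedding = np.arange(12, dtype=np.float32).reshape(4, 3)  # [vocab_size=4, embed_dim=3]
ids = np.array([3, 0], dtype=np.int32)                     # argmax ids for a batch of 2

looked_up = default_embedding_fn(embedding)(ids)  # rows 3 and 0 of `embedding`
one_hot = one_hot_embedding_fn(4)(ids)            # one-hot rows for ids 3 and 0
```

Either way, the callable only has to map a vector of int32 ids to a tensor the decoder can consume.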
Attributes
batch_size | Batch size of the tensor returned by sample. Returns a scalar int32 tensor. The value might not be available before initialize() is called; in that case, a ValueError is raised.
sample_ids_dtype | DType of the tensor returned by sample. Returns a DType. The value might not be available before initialize() is called.
sample_ids_shape | Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape. The value might not be available before initialize() is called.
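The reason these attributes may be unavailable early is that the batch size is only known once initialize() has seen the start tokens. The hypothetical toy class below (LazyBatchSizeSketch, not part of TFA) illustrates that pattern. The scalar per-example shape and int32 dtype it reports are assumptions about greedy ids, consistent with argmax producing one int32 id per batch element.

```python
import numpy as np


class LazyBatchSizeSketch:
    """Hypothetical toy class: batch_size is unknown until initialize() runs."""

    def __init__(self):
        self._batch_size = None

    @property
    def batch_size(self) -> int:
        if self._batch_size is None:
            # Accessing batch_size too early raises, as documented above.
            raise ValueError("batch_size accessed before initialize was called")
        return self._batch_size

    @property
    def sample_ids_shape(self) -> tuple:
        # One id per batch element, so the per-example shape is a scalar: ().
        return ()

    @property
    def sample_ids_dtype(self):
        return np.int32

    def initialize(self, start_tokens: np.ndarray) -> None:
        # The batch size is inferred from the length of the start-token vector.
        self._batch_size = int(start_tokens.shape[0])


sampler = LazyBatchSizeSketch()
try:
    sampler.batch_size
    raised_early = False
except ValueError as err:
    raised_early = True
    print("before initialize:", err)

sampler.initialize(np.array([1, 1, 1], dtype=np.int32))
print(sampler.batch_size)  # 3
```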
Methods
initialize
initialize(
embedding, start_tokens=None, end_token=None
)
Initialize the GreedyEmbeddingSampler.
Args
embedding | A tensor containing the embedding matrix. It is used to embed start_tokens and the sampled ids. The embedding is ignored if embedding_fn was provided to the constructor.
start_tokens | int32 vector of shape [batch_size]; the start tokens.
end_token | int32 scalar; the token that marks the end of decoding.

Returns
A tuple of two items: (finished, self.start_inputs).

Raises
ValueError | If start_tokens is not a 1-D tensor or end_token is not a scalar.
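The contract of initialize() can be sketched in NumPy. initialize_sketch is a hypothetical function, not the TFA implementation: it resolves embedding_fn (falling back to a row lookup into embedding), applies the shape checks listed under Raises, and returns (finished, start_inputs) with nothing finished yet.

```python
import numpy as np


def initialize_sketch(embedding, start_tokens, end_token, embedding_fn=None):
    """Hypothetical NumPy sketch of what initialize() is documented to do.

    Returns (finished, start_inputs); finished is all False at the start.
    """
    # An embedding_fn given at construction time wins; `embedding` is then ignored.
    if embedding_fn is None:
        embedding_fn = lambda ids: np.take(embedding, ids, axis=0)
    start_tokens = np.asarray(start_tokens, dtype=np.int32)
    end_token = np.asarray(end_token, dtype=np.int32)
    # Shape checks from the Raises section.
    if start_tokens.ndim != 1:
        raise ValueError("start_tokens must be a vector")
    if end_token.ndim != 0:
        raise ValueError("end_token must be a scalar")
    # No sequence has finished before the first decoding step.
    finished = np.zeros(start_tokens.shape[0], dtype=bool)
    # The first decoder inputs are the embedded start tokens.
    start_inputs = embedding_fn(start_tokens)
    return finished, start_inputs


embedding = np.eye(4, dtype=np.float32)  # toy [vocab_size=4, embed_dim=4] matrix
finished, start_inputs = initialize_sketch(embedding, start_tokens=[1, 1], end_token=2)

try:
    initialize_sketch(embedding, start_tokens=[[1, 1]], end_token=2)
    rejected = False
except ValueError:
    rejected = True  # a 2-D start_tokens is refused
```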
next_inputs
next_inputs(
time, outputs, state, sample_ids
)
Computes the next decoder inputs from sample_ids (the next_inputs_fn of the former GreedyEmbeddingHelper).
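A possible shape of the greedy next_inputs step is sketched below in NumPy. The behavior shown (finished marks ids equal to end_token, start_inputs is reused once every sequence is done, and the state passes through unchanged) is an assumption based on our reading of the TFA source, and next_inputs_sketch is a hypothetical name.

```python
import numpy as np


def next_inputs_sketch(sample_ids, end_token, state, embedding_fn, start_inputs):
    """Hypothetical NumPy sketch of the greedy next_inputs step.

    Returns (finished, next_inputs, state).
    """
    # A sequence counts as finished at this step if it emitted end_token.
    finished = sample_ids == end_token
    if np.all(finished):
        # Every sequence is done: skip the lookup and reuse start_inputs as a placeholder.
        next_inputs = start_inputs
    else:
        # Feed the embedding of each greedily chosen id back into the decoder.
        next_inputs = embedding_fn(sample_ids)
    # The sampler passes the decoder state through unchanged.
    return finished, next_inputs, state


embedding = np.arange(8, dtype=np.float32).reshape(4, 2)  # [vocab_size=4, embed_dim=2]
embed = lambda ids: np.take(embedding, ids, axis=0)
start_inputs = embed(np.array([1, 1]))
end_token = 3

# Batch element 0 emits the end token, element 1 does not.
finished, inputs, _ = next_inputs_sketch(np.array([3, 2]), end_token, None, embed, start_inputs)
# Both batch elements emit the end token.
all_done, done_inputs, _ = next_inputs_sketch(np.array([3, 3]), end_token, None, embed, start_inputs)
```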
sample
sample(
time, outputs, state
)
Samples ids by taking the argmax of outputs (the sample function of the former GreedyEmbeddingHelper).
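The following toy loop shows how initialize, sample, and next_inputs cooperate during greedy decoding. In TFA this loop is normally driven by a decoder such as tfa.seq2seq.BasicDecoder. Here greedy_decode and step_fn are hypothetical NumPy stand-ins, and padding already-finished rows with end_token is a simplification of this sketch, not documented library behavior.

```python
import numpy as np


def greedy_decode(step_fn, embedding, start_tokens, end_token, max_steps):
    """Toy greedy decoding loop (hypothetical, NumPy only).

    step_fn(inputs, state) -> (logits, state) stands in for the decoder cell
    plus output projection. Returns ids of shape [batch_size, num_steps].
    """
    embed = lambda ids: np.take(embedding, ids, axis=0)
    start_tokens = np.asarray(start_tokens, dtype=np.int32)
    # initialize: nothing finished yet; first inputs are the embedded start tokens.
    finished = np.zeros(start_tokens.shape[0], dtype=bool)
    inputs, state = embed(start_tokens), None
    outputs = []
    for _ in range(max_steps):
        logits, state = step_fn(inputs, state)
        # sample: greedy argmax over the logits.
        ids = np.argmax(logits, axis=-1).astype(np.int32)
        # Sketch-only simplification: finished rows keep emitting end_token.
        ids = np.where(finished, end_token, ids)
        outputs.append(ids)
        # next_inputs: a row is finished once it has produced end_token.
        finished = finished | (ids == end_token)
        if finished.all():
            break
        inputs = embed(ids)
    return np.stack(outputs, axis=1)


# Toy model: with one-hot inputs, token i always predicts token (i + 1) % 4.
transition = np.roll(np.eye(4), 1, axis=1)
step_fn = lambda inputs, state: (inputs @ transition, state)

result = greedy_decode(step_fn, np.eye(4), start_tokens=[0, 1], end_token=3, max_steps=10)
```

Starting from ids 0 and 1 with end_token 3, the toy model yields [[1, 2, 3], [2, 3, 3]]: the second sequence finishes one step early, and the loop stops once both rows have emitted the end token.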
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-05-25 UTC.