An inference sampler that randomly samples from the output distribution.
Inherits From: GreedyEmbeddingSampler, Sampler
tfa.seq2seq.SampleEmbeddingSampler(
embedding_fn: Optional[Callable] = None,
softmax_temperature: Optional[TensorLike] = None,
seed: Optional[TensorLike] = None
)
Uses sampling from the output distribution instead of argmax, and passes the sampled ids through an embedding layer to get the next input.
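The effect of softmax_temperature on the distribution being sampled can be sketched in plain Python. This assumes the usual temperature convention, where the logits are divided by the temperature before the softmax. The helper softmax_with_temperature is hypothetical and not part of tfa.

```python
import math
from typing import List, Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Return softmax(logits / temperature); temperature must be > 0."""
    if temperature <= 0:
        raise ValueError("temperature must be strictly greater than 0")
    scaled = [x / temperature for x in logits]
    # Subtract the max before exponentiating for numerical stability.
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.1, 1.0, 10.0):
    # Small t sharpens the distribution toward the argmax; large t flattens it.
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

Drawing an id from these probabilities, rather than always taking index 0 (the argmax), is the difference between this sampler and a greedy one.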
Raises | |
---|---|
ValueError | If start_tokens is not a 1D tensor or end_token is not a scalar.
Methods
initialize
initialize(
embedding, start_tokens=None, end_token=None
)
Initializes the sampler (inherited from GreedyEmbeddingSampler).
Args | |
---|---|
embedding | Tensor containing the embedding matrix. It is used to generate outputs with start_tokens and end_token. The embedding is ignored if embedding_fn was provided at init().
start_tokens | int32 vector shaped [batch_size], the start tokens.
end_token | int32 scalar, the token that marks the end of decoding.
Returns | |
---|---|
Tuple of two items: (finished, self.start_inputs). |
Raises | |
---|---|
ValueError | If start_tokens is not a 1D tensor or end_token is not a scalar.
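A minimal plain-Python sketch of what initialize does, assuming that when no embedding_fn was given at init() the embedding is used as a simple row-lookup table. The function initialize_sketch and the list-based "tensors" are hypothetical stand-ins for TensorFlow tensors, not the tfa implementation.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]


def initialize_sketch(
    embedding: Sequence[Vector],
    start_tokens,
    end_token,
    embedding_fn: Optional[Callable[[List[int]], List[Vector]]] = None,
) -> Tuple[List[bool], List[Vector]]:
    # Documented checks: start_tokens must be 1-D and end_token a scalar.
    if not isinstance(start_tokens, (list, tuple)) or any(
        isinstance(t, (list, tuple)) for t in start_tokens
    ):
        raise ValueError("start_tokens must be a 1-D vector")
    if isinstance(end_token, (list, tuple)):
        raise ValueError("end_token must be a scalar")

    def default_lookup(ids: List[int]) -> List[Vector]:
        # Assumed default: plain row lookup into the embedding matrix.
        return [list(embedding[i]) for i in ids]

    # A user-supplied embedding_fn takes precedence over the matrix.
    lookup = embedding_fn if embedding_fn is not None else default_lookup
    # No sequence has finished before the first decoding step.
    finished = [False] * len(start_tokens)
    start_inputs = lookup(list(start_tokens))
    return finished, start_inputs


emb = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.9]]  # vocab size 3, embedding dim 2
finished, start_inputs = initialize_sketch(emb, start_tokens=[1, 1], end_token=2)
```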
next_inputs
next_inputs(
time, outputs, state, sample_ids
)
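Computes the next decoder inputs (inherited from GreedyEmbeddingSampler). A batch entry is marked finished when its sample_ids value equals end_token; the method returns (finished, next_inputs, state), where next_inputs is the embedding of sample_ids.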
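The end-of-sequence logic of next_inputs can be sketched with Python lists in place of tensors. next_inputs_sketch is a hypothetical helper, and returning start_inputs as a placeholder once every sequence has finished is an assumption made for illustration.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def next_inputs_sketch(
    sample_ids: Sequence[int],
    end_token: int,
    start_inputs: List[Vector],
    embedding_fn: Callable[[List[int]], List[Vector]],
    state: object,
) -> Tuple[List[bool], List[Vector], object]:
    # A batch entry is finished once its sampled id equals end_token.
    finished = [sid == end_token for sid in sample_ids]
    if all(finished):
        # Every sequence is done, so the next inputs are never consumed;
        # start_inputs serves as a cheap placeholder.
        next_inputs = start_inputs
    else:
        # Otherwise feed the embedded sampled ids back into the decoder.
        next_inputs = embedding_fn(list(sample_ids))
    # The decoder state is passed through unchanged.
    return finished, next_inputs, state


emb = [[0.0, 0.0], [1.0, 0.5], [0.2, 0.9]]


def lookup(ids: List[int]) -> List[Vector]:
    return [list(emb[i]) for i in ids]


start = lookup([1, 1])
finished, nxt, _ = next_inputs_sketch(
    [2, 0], end_token=2, start_inputs=start, embedding_fn=lookup, state=None
)
```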
sample
sample(
time, outputs, state
)
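Samples ids for SampleEmbeddingSampler: draws from the categorical distribution defined by the logits in outputs, divided by softmax_temperature when one was given.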
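A plain-Python sketch of sample, drawing one id per batch row from softmax(logits / softmax_temperature). The function sample_sketch, its list-based inputs, and the use of Python's random.Random for seeding are assumptions for illustration; the library itself samples with TensorFlow ops.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_sketch(
    logits: Sequence[Sequence[float]],
    softmax_temperature: Optional[float] = None,
    seed: Optional[int] = None,
) -> List[int]:
    """Draw one id per batch row from softmax(logits / softmax_temperature)."""
    rng = random.Random(seed)
    sample_ids = []
    for row in logits:
        if softmax_temperature is not None:
            row = [x / softmax_temperature for x in row]
        # Unnormalized softmax weights; subtracting the max avoids overflow.
        m = max(row)
        weights = [math.exp(x - m) for x in row]
        # random.choices normalizes the weights, so this is a categorical draw.
        sample_ids.append(rng.choices(range(len(row)), weights=weights, k=1)[0])
    return sample_ids


batch_logits = [[0.0, 5.0, 1.0], [3.0, 0.0, 0.0]]
sampled = sample_sketch(batch_logits, softmax_temperature=1.0, seed=42)
```

With a very small temperature the draw collapses onto the per-row argmax, which is what GreedyEmbeddingSampler would return.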