Implementation for sampling strategies (go/decoding-tf-nlp).
```python
tfm.nlp.ops.SamplingModule(
    symbols_to_logits_fn,
    vocab_size: int,
    max_decode_length: int,
    eos_id: int,
    padded_decode: bool,
    length_normalization_fn: Optional[Callable[[int, tf.DType], float]] = None,
    top_k=0,
    top_p=1.0,
    sample_temperature=0.0,
    enable_greedy: bool = True,
    dtype: tf.DType = tf.float32,
    decoding_name: Optional[str] = None,
    extra_cache_output: bool = False
)
```
## Methods

### generate
```python
generate(
    initial_ids: tf.Tensor,
    initial_cache: Dict[str, tf.Tensor],
    initial_log_probs: Optional[tf.Tensor] = None
) -> Output
```
Implements the decoding strategy (beam_search or sampling).
| Args | |
|---|---|
| `initial_ids` | Initial ids to pass into the `symbols_to_logits_fn`. An int tensor with shape `[batch_size, 1]`. |
| `initial_cache` | Dictionary for caching model outputs from the previous step. |
| `initial_log_probs` | Optional initial log probs, used when decoding should start from a prefix sequence. |
| Returns | |
|---|---|
| Tuple of tensors: `finished_sequence` with shape `[batch, max_seq_length]`, `finished_scores` with shape `[batch]`, and `first_cache`, the cache after the initial token. |
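Continuing the sketch above, a call to `generate` might look as follows. The empty `initial_cache` is an assumption that works only because the toy `symbols_to_logits_fn` ignores the cache; a real decoder would pre-populate it with its state tensors. The result is left unpacked, since its exact structure depends on the configuration (e.g. `extra_cache_output`).

```python
batch_size = 2
# Start ids, shaped [batch_size, 1] as the Args table describes.
initial_ids = tf.zeros([batch_size, 1], dtype=tf.int32)
# The toy symbols_to_logits_fn ignores the cache, so an empty dict suffices here.
initial_cache = {}

output = sampler.generate(initial_ids=initial_ids, initial_cache=initial_cache)
```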
### inf

```python
inf()
```
Returns a value close to infinity, but still finite in `dtype`.

This is useful for obtaining a very large value that is still zero when multiplied by zero, since the floating-point `inf` value becomes NaN when multiplied by zero.
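A quick illustration of why a large finite value is preferred over true IEEE infinity; `tf.float32.max` below is merely a stand-in in the same spirit as `inf()`:

```python
import tensorflow as tf

true_inf = tf.constant(float("inf"), dtype=tf.float32)
near_inf = tf.constant(tf.float32.max)  # finite stand-in, in the spirit of inf()

print((true_inf * 0.0).numpy())  # nan: IEEE infinity times zero is NaN
print((near_inf * 0.0).numpy())  # 0.0: a finite value times zero stays zero
```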
| Returns | |
|---|---|
| A very large value. |