tfm.nlp.models.T5Transformer
Transformer encoder-decoder for sequence-to-sequence modeling.
tfm.nlp.models.T5Transformer(
config: tfm.nlp.models.T5TransformerParams,
compute_dtype: tf.DType = tf.float32,
**kwargs
)
Args
dtype | The variable allocation dtype.
name | A string for the module name.
Attributes
checkpoint_items
Methods
create_variable
create_variable(
name: Text,
shape: ShapeLike,
initializer: Initializer,
dtype: tf.DType = tf.float32,
**kwargs
)
decode
decode(
encoded,
decoder_target_tokens,
encoder_input_tokens=None,
encoder_dense_inputs=None,
decoder_input_tokens=None,
encoder_segment_ids=None,
encoder_dense_segment_ids=None,
decoder_segment_ids=None,
decode_position=None,
cache=None,
max_decode_len=None,
decode=False,
training=False
) -> Dict[str, tf.Tensor]
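For training, decode takes decoder_input_tokens alongside decoder_target_tokens. A common convention in T5-style models is teacher forcing: the decoder inputs are the targets shifted one position to the right, with a start token in front. The stdlib-only sketch below illustrates that idea; `shift_right` is a hypothetical helper, and the start id of 0 and the per-segment reset for packed examples are assumptions, not a description of this library's own preprocessing.

```python
from typing import List, Optional


def shift_right(
    targets: List[int],
    segment_ids: Optional[List[int]] = None,
    start_id: int = 0,  # assumption: T5-style start token id
) -> List[int]:
    """Build teacher-forcing decoder inputs by shifting targets one step right.

    Hypothetical helper. When segment_ids is given (packed examples), the first
    position of every new segment is reset to start_id so that a segment never
    sees the last token of the previous one.
    """
    if not targets:
        return []
    shifted = [start_id] + targets[:-1]
    if segment_ids is not None:
        for i in range(1, len(shifted)):
            if segment_ids[i] != segment_ids[i - 1]:
                shifted[i] = start_id
    return shifted
```

For a packed row with targets [5, 6, 7, 8, 9, 0] and segment ids [1, 1, 1, 2, 2, 0], this yields [0, 5, 6, 0, 8, 0]: each segment starts from the start token.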
encode
encode(
encoder_input_tokens=None,
encoder_segment_ids=None,
encoder_dense_inputs=None,
encoder_dense_segment_ids=None,
training=False
)
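encode accepts encoder_segment_ids for packed examples. One way to picture their effect is as a self-attention mask in which a token may only attend to tokens of its own segment. The sketch below assumes segment id 0 marks padding; `encoder_attention_mask` and `decoder_attention_mask` are hypothetical helpers, not part of tfm, and the library's layers may build their masks differently.

```python
from typing import List


def encoder_attention_mask(segment_ids: List[int]) -> List[List[bool]]:
    """mask[q][k] is True when query position q may attend to key position k.

    Bidirectional: every position in the same non-padding segment is visible.
    Assumes segment id 0 marks padding.
    """
    return [
        [q_seg != 0 and q_seg == k_seg for k_seg in segment_ids]
        for q_seg in segment_ids
    ]


def decoder_attention_mask(segment_ids: List[int]) -> List[List[bool]]:
    """Same-segment mask combined with a causal constraint (k <= q)."""
    n = len(segment_ids)
    return [
        [
            segment_ids[q] != 0 and segment_ids[q] == segment_ids[k] and k <= q
            for k in range(n)
        ]
        for q in range(n)
    ]
```

The encoder variant lets tokens see their whole segment, while the decoder variant additionally hides future positions, as autoregressive decoding requires.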
read_variable
read_variable(
variable: tf.Variable, as_dtype: Optional[tf.DType] = None
)
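The constructor's compute_dtype is separate from the variable allocation dtype, and read_variable takes an optional as_dtype. Assuming as_dtype casts the value on read, the numeric effect of reading a full-precision variable as float16 can be sketched with the standard library's struct module, whose 'e' format is IEEE 754 half precision. `to_float16` and `read_as` are hypothetical helpers; Python floats are 64-bit, so this only illustrates the rounding, not TensorFlow's implementation.

```python
import struct
from typing import List, Optional


def to_float16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value via struct's 'e' format."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def read_as(stored: List[float], as_dtype: Optional[str] = None) -> List[float]:
    """Hypothetical stand-in for reading a variable, optionally cast to lower precision."""
    if as_dtype is None:
        # No cast requested: return values at storage precision.
        return list(stored)
    if as_dtype == "float16":
        return [to_float16(v) for v in stored]
    raise ValueError(f"unsupported as_dtype: {as_dtype!r}")
```

Reading [0.1, 1.0] as "float16" returns [0.0999755859375, 1.0]. This kind of precision loss is why mixed-precision setups typically keep variables in float32 and cast only for computation.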
__call__
__call__(
encoder_input_tokens=None,
decoder_target_tokens=None,
encoder_dense_inputs=None,
encoder_dense_segment_ids=None,
decoder_input_tokens=None,
encoder_segment_ids=None,
decoder_segment_ids=None,
training=False
)
Applies the Transformer model to the inputs.
Args
encoder_input_tokens | Input tokens to the encoder.
decoder_target_tokens | Target tokens for the decoder.
encoder_dense_inputs | Input dense vectors to the encoder.
encoder_dense_segment_ids | Dense input segmentation info for packed examples.
decoder_input_tokens | Input tokens to the decoder; only required for training.
encoder_segment_ids | Input segmentation info for packed examples.
decoder_segment_ids | Target segmentation info for packed examples.
training | Whether this is a training pass; affects dropout.
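The segment-id arguments refer to packed examples, where several short sequences share one fixed-length row. A minimal sketch of producing such a row, assuming 1-based segment ids with 0 for padding; `pack_examples` is a hypothetical helper, not a tfm function.

```python
from typing import List, Sequence, Tuple


def pack_examples(
    examples: Sequence[List[int]], length: int, pad_id: int = 0
) -> Tuple[List[int], List[int]]:
    """Concatenate token lists into one row of `length` tokens.

    Returns (tokens, segment_ids): segment ids are 1-based per example, and
    padding positions get segment id 0 (an assumed convention).
    """
    tokens: List[int] = []
    segment_ids: List[int] = []
    for seg, example in enumerate(examples, start=1):
        tokens.extend(example)
        segment_ids.extend([seg] * len(example))
    if len(tokens) > length:
        raise ValueError("examples do not fit in the packed row")
    pad = length - len(tokens)
    return tokens + [pad_id] * pad, segment_ids + [0] * pad
```

Packing [[5, 6, 7], [8, 9]] into length 6 gives tokens [5, 6, 7, 8, 9, 0] with segment ids [1, 1, 1, 2, 2, 0].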
Returns
A dictionary of logits/cache.
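Because the model returns a dictionary of logits and cache (decode is likewise typed to return Dict[str, tf.Tensor]), autoregressive generation can be driven one step at a time by feeding the returned cache back in. The sketch below shows greedy decoding in that shape. `step_fn` and `toy_step` are hypothetical stand-ins for an incremental decode call (decode=True with cache and decode_position), and the start id 0 and EOS id 1 are assumptions borrowed from the original T5 vocabulary convention.

```python
from typing import Any, Callable, Dict, List

# Hypothetical step signature: (previous token, position, cache) -> {"logits": [...], "cache": {...}}
StepFn = Callable[[int, int, Dict[str, Any]], Dict[str, Any]]


def greedy_decode(
    step_fn: StepFn, max_decode_len: int, start_id: int = 0, eos_id: int = 1
) -> List[int]:
    """Greedy decoding: pick the argmax token each step and thread the cache through."""
    cache: Dict[str, Any] = {}
    token = start_id
    output: List[int] = []
    for position in range(max_decode_len):
        result = step_fn(token, position, cache)
        logits, cache = result["logits"], result["cache"]
        # Argmax over the vocabulary without external libraries.
        token = max(range(len(logits)), key=logits.__getitem__)
        output.append(token)
        if token == eos_id:
            break
    return output


def toy_step(token: int, position: int, cache: Dict[str, Any]) -> Dict[str, Any]:
    """Toy step function: emits position + 2 for three steps, then EOS (id 1)."""
    history = cache.get("history", []) + [token]
    logits = [0.0] * 5
    logits[1 if position >= 3 else position + 2] = 1.0
    return {"logits": logits, "cache": {"history": history}}
```

With max_decode_len=10, greedy_decode(toy_step, 10) returns [2, 3, 4, 1] and stops at EOS. The toy cache only records history; in an incremental Transformer decoder it would typically hold per-layer key/value states so earlier positions are not recomputed.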
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-02-02 UTC.