Task object for tagging (e.g., NER or POS).
Inherits From: Task
tfm.nlp.tasks.TaggingTask(
params, logging_dir: Optional[str] = None, name: Optional[str] = None
)
Attributes | |
---|---|
logging_dir | |
task_config | |
Methods
aggregate_logs
aggregate_logs(
state=None, step_outputs=None
)
Aggregates over logs returned from a validation step.
build_inputs
build_inputs(
params: tfm.core.config_definitions.DataConfig, input_context=None
)
Returns a tf.data.Dataset for the tagging task.
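Tagging inputs are usually padded to a fixed sequence length, with one label per token. The plain-Python sketch below shows one way to pad a single example. The feature names (`input_word_ids`, `input_mask`, `input_type_ids`, `label_ids`), the padding token id 0, and the ignore label -1 are assumptions about the data format, and `pad_tagging_example` is a hypothetical helper, not part of `tfm`.

```python
from typing import Dict, List

PAD_TOKEN_ID = 0   # assumed id of the padding token
PAD_LABEL_ID = -1  # assumed label id for positions the loss should ignore


def pad_tagging_example(
    token_ids: List[int], label_ids: List[int], seq_length: int
) -> Dict[str, List[int]]:
    """Hypothetical helper: truncate or pad one example to seq_length."""
    if len(token_ids) != len(label_ids):
        raise ValueError("token_ids and label_ids must have the same length")
    token_ids = token_ids[:seq_length]
    label_ids = label_ids[:seq_length]
    num_pad = seq_length - len(token_ids)
    return {
        "input_word_ids": token_ids + [PAD_TOKEN_ID] * num_pad,
        # 1 marks a real token, 0 marks padding.
        "input_mask": [1] * len(token_ids) + [0] * num_pad,
        # Single-segment input: every position belongs to segment 0.
        "input_type_ids": [0] * seq_length,
        # Padded positions get the ignore label so they drop out of the loss.
        "label_ids": label_ids + [PAD_LABEL_ID] * num_pad,
    }


example = pad_tagging_example([101, 7592, 2088, 102], [-1, 1, 2, -1], seq_length=6)
print(example["input_mask"])  # [1, 1, 1, 1, 0, 0]
print(example["label_ids"])   # [-1, 1, 2, -1, -1, -1]
```

Here the caller has already given special tokens the ignore label as well, so a single `label >= 0` check can exclude both special tokens and padding downstream.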
build_losses
build_losses(
labels, model_outputs, aux_losses=None
) -> tf.Tensor
Standard interface to compute losses.
Args | |
---|---|
labels | optional label tensors. |
model_outputs | a nested structure of output tensors. |
aux_losses | auxiliary loss tensors, i.e. losses in keras.Model. |
Returns | |
---|---|
The total loss tensor. |
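For token tagging, a common loss is softmax cross-entropy per token, averaged over the positions that carry a real label. The sketch below assumes the model outputs have been reduced to per-token logits and that ignored positions have a negative label id. `masked_token_cross_entropy` is a hypothetical name, and the library's own loss may differ in details.

```python
import math
from typing import Sequence


def masked_token_cross_entropy(
    labels: Sequence[Sequence[int]],
    logits: Sequence[Sequence[Sequence[float]]],
) -> float:
    """Mean softmax cross-entropy over positions whose label is >= 0."""
    total_loss = 0.0
    num_valid = 0
    for seq_labels, seq_logits in zip(labels, logits):
        for label, token_logits in zip(seq_labels, seq_logits):
            if label < 0:  # assumed convention: negative label = ignore
                continue
            # Numerically stable log-sum-exp: subtract the max logit first.
            max_logit = max(token_logits)
            log_sum_exp = max_logit + math.log(
                sum(math.exp(z - max_logit) for z in token_logits))
            # -log softmax(label) = logsumexp(logits) - logits[label]
            total_loss += log_sum_exp - token_logits[label]
            num_valid += 1
    # A batch of only ignored positions returns 0.0 instead of dividing by 0.
    return total_loss / num_valid if num_valid else 0.0


# Two valid tokens with uniform logits over two classes; the third is ignored.
loss = masked_token_cross_entropy(
    [[1, 0, -1]], [[[0.0, 0.0], [0.0, 0.0], [5.0, -5.0]]])
```

Dividing by the number of valid tokens, rather than by the padded length, keeps the loss scale independent of how much padding a batch contains.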
build_metrics
build_metrics(
training: bool = True
)
Gets streaming metrics for training/validation.
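A streaming metric keeps running totals across batches and only produces a value when asked. The class below is a plain-Python stand-in that borrows the `update_state` / `result` / `reset_state` method names used by Keras metrics; it is hypothetical and is not what `build_metrics` actually returns. It assumes negative labels mark padding.

```python
from typing import Sequence


class StreamingMaskedAccuracy:
    """Hypothetical streaming metric: token accuracy that ignores labels < 0."""

    def __init__(self, name: str = "accuracy"):
        self.name = name
        self.reset_state()

    def reset_state(self) -> None:
        # Totals accumulate across update_state calls until reset.
        self.correct = 0
        self.count = 0

    def update_state(self, labels: Sequence[int], predictions: Sequence[int]) -> None:
        for label, pred in zip(labels, predictions):
            if label < 0:  # assumed padding label
                continue
            self.correct += int(label == pred)
            self.count += 1

    def result(self) -> float:
        return self.correct / self.count if self.count else 0.0


metric = StreamingMaskedAccuracy()
metric.update_state([1, 2, -1], [1, 0, 5])  # batch 1: 1 of 2 valid tokens correct
metric.update_state([3], [3])               # batch 2: 1 of 1 correct
print(metric.result())  # 2 of 3 correct overall
```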
build_model
build_model()
[Optional] Creates model architecture.
Returns | |
---|---|
A model instance. |
create_optimizer
@classmethod
create_optimizer(
optimizer_config: tfm.optimization.OptimizationConfig,
runtime_config: Optional[tfm.core.base_task.RuntimeConfig] = None,
dp_config: Optional[tfm.core.base_task.DifferentialPrivacyConfig] = None
)
Creates a TF optimizer from configurations.
Args | |
---|---|
optimizer_config | the parameters of the optimization settings. |
runtime_config | the parameters of the runtime. |
dp_config | the parameters of differential privacy. |
Returns | |
---|---|
A tf.optimizers.Optimizer object. |
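An optimization config typically pairs an optimizer with a learning-rate schedule and an optional warmup. As an illustration of what such a schedule computes, here is one common formulation, linear warmup followed by polynomial decay, in plain Python. Both function names are hypothetical, and this is not a claim about the exact schedule classes in `tfm.optimization`.

```python
def polynomial_decay(step, initial_lr, decay_steps, end_lr=0.0, power=1.0):
    """Decay from initial_lr to end_lr over decay_steps, then stay at end_lr."""
    progress = min(step, decay_steps) / decay_steps
    return (initial_lr - end_lr) * (1.0 - progress) ** power + end_lr


def warmup_polynomial_lr(
    step, initial_lr, decay_steps, warmup_steps, end_lr=0.0, power=1.0
):
    """Hypothetical schedule: linear warmup, then polynomial decay."""
    if step < warmup_steps:
        # Ramp toward the value the decay curve has at warmup_steps, so the
        # schedule is continuous where warmup ends.
        target = polynomial_decay(warmup_steps, initial_lr, decay_steps, end_lr, power)
        return target * step / warmup_steps
    return polynomial_decay(step, initial_lr, decay_steps, end_lr, power)


for step in (0, 5, 10, 55, 100):
    print(step, warmup_polynomial_lr(step, initial_lr=1.0, decay_steps=100, warmup_steps=10))
```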
inference_step
inference_step(
inputs, model: tf.keras.Model
)
Performs the forward step.
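The forward step yields per-token scores; turning them into tag ids is a separate step. Greedy argmax per token is the simplest decoding. The helper below is a hypothetical plain-Python sketch that assumes per-token logits and an `input_mask`, and marks padded positions with -1.

```python
from typing import List, Sequence


def argmax_tags(
    logits: Sequence[Sequence[float]], input_mask: Sequence[int]
) -> List[int]:
    """Pick the highest-scoring tag id per token; padded positions get -1."""
    predictions = []
    for token_logits, is_real in zip(logits, input_mask):
        if not is_real:
            predictions.append(-1)  # assumed marker for padded positions
            continue
        # Index of the largest logit (first one wins on ties).
        best = max(range(len(token_logits)), key=lambda i: token_logits[i])
        predictions.append(best)
    return predictions


preds = argmax_tags([[0.1, 2.0, -1.0], [3.0, 0.0, 0.0], [0.0, 0.0, 9.0]], [1, 1, 0])
# preds == [1, 0, -1]
```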
initialize
initialize(
model: tf.keras.Model
)
[Optional] A callback function used as CheckpointManager's init_fn.
This function will be called when no checkpoint is found for the model. If there is a checkpoint, the checkpoint will be loaded and this function will not be called. You can use this callback function to load a pretrained checkpoint, saved under a directory other than the model_dir.
Args | |
---|---|
model | The keras.Model built or used by this task. |
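The contract described above is that the init function runs only when no checkpoint exists. `tf.train.CheckpointManager` accepts an `init_fn` and exposes `restore_or_initialize()` for this purpose. The following is a plain-Python model of that control flow only, not TensorFlow code; `restore_or_initialize` here is a hypothetical standalone function with simplified arguments.

```python
from typing import Callable, Optional


def restore_or_initialize(
    latest_checkpoint: Optional[str],
    restore: Callable[[str], None],
    init_fn: Callable[[], None],
) -> str:
    """Model of the documented contract: init_fn runs only without a checkpoint."""
    if latest_checkpoint is not None:
        restore(latest_checkpoint)  # resume from model_dir; init_fn is skipped
        return "restored"
    init_fn()  # e.g. warm-start from a pretrained checkpoint stored elsewhere
    return "initialized"


calls = []
first = restore_or_initialize(None, calls.append, lambda: calls.append("pretrained"))
second = restore_or_initialize("ckpt-100", calls.append, lambda: calls.append("pretrained"))
# first == "initialized", second == "restored", calls == ["pretrained", "ckpt-100"]
```

Because a training checkpoint always wins over `init_fn`, a job that restarts mid-training resumes its own progress instead of reloading the pretrained weights.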
process_compiled_metrics
process_compiled_metrics(
compiled_metrics, labels, model_outputs
)
Process and update compiled_metrics.
Called when using the compile/fit API.
Args | |
---|---|
compiled_metrics | the compiled metrics (model.compiled_metrics). |
labels | a tensor or a nested structure of tensors. |
model_outputs | a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model. |
process_metrics
process_metrics(
metrics, labels, model_outputs, **kwargs
)
Process and update metrics.
Called when using custom training loop API.
Args | |
---|---|
metrics | a nested structure of metrics objects. The return of function self.build_metrics. |
labels | a tensor or a nested structure of tensors. |
model_outputs | a tensor or a nested structure of tensors. For example, output of the keras model built by self.build_model. |
**kwargs | other args. |
reduce_aggregated_logs
reduce_aggregated_logs(
aggregated_logs, global_step=None
)
Reduces aggregated logs over validation steps.
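Together, `aggregate_logs` and `reduce_aggregated_logs` form an accumulate-then-reduce pattern. Each validation step's outputs are folded into a running state, and a single metric dictionary is computed at the end. The plain-Python sketch below assumes the step outputs hold `predict_ids` and `label_ids` and that negative labels mark padding. Both functions are hypothetical and simplified; for NER, entity-level precision, recall, and F1 are commonly reported instead of the token accuracy shown here.

```python
from typing import Dict, List, Optional


def accumulate_step_outputs(
    state: Optional[Dict[str, List[int]]],
    step_outputs: Dict[str, List[List[int]]],
) -> Dict[str, List[int]]:
    """Fold one validation step's unpadded predictions and labels into state."""
    if state is None:
        state = {"predict_ids": [], "label_ids": []}
    for preds, labels in zip(step_outputs["predict_ids"], step_outputs["label_ids"]):
        for pred, label in zip(preds, labels):
            if label >= 0:  # assumed: negative labels mark padding
                state["predict_ids"].append(pred)
                state["label_ids"].append(label)
    return state


def reduce_to_metrics(state: Dict[str, List[int]]) -> Dict[str, float]:
    """Reduce the accumulated ids to one token-level accuracy."""
    labels = state["label_ids"]
    correct = sum(int(p == l) for p, l in zip(state["predict_ids"], labels))
    return {"accuracy": correct / len(labels) if labels else 0.0}


state = None
for step_outputs in [
    {"predict_ids": [[1, 2, 0]], "label_ids": [[1, 1, -1]]},
    {"predict_ids": [[0, 3]], "label_ids": [[0, 3]]},
]:
    state = accumulate_step_outputs(state, step_outputs)
print(reduce_to_metrics(state))  # {'accuracy': 0.75}
```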
train_step
train_step(
inputs,
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics=None
)
Does forward and backward.
With distribution strategies, this method runs on devices.
Args | |
---|---|
inputs | a dictionary of input tensors. |
model | the model, forward pass definition. |
optimizer | the optimizer for this training step. |
metrics | a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
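A training step runs the forward pass, computes the loss, backpropagates, and applies the optimizer update. In TensorFlow this is usually done with `tf.GradientTape`. The sketch below instead writes the gradients out by hand for a toy per-token linear classifier, so each stage is visible. The toy model, the `features` input key, the -1 padding label, and the plain SGD update are all assumptions, not `TaggingTask` internals.

```python
import math
from typing import Dict, List


def toy_train_step(
    inputs: Dict[str, list],
    weights: List[List[float]],
    learning_rate: float = 0.1,
) -> Dict[str, float]:
    """Hypothetical: one SGD step for a per-token linear classifier (no bias)."""
    features, labels = inputs["features"], inputs["label_ids"]
    num_classes = len(weights)
    grads = [[0.0] * len(weights[0]) for _ in range(num_classes)]
    total_loss, count = 0.0, 0
    for x, y in zip(features, labels):
        if y < 0:  # skip padded tokens (assumed label convention)
            continue
        # Forward: logits = W x, then a numerically stable softmax.
        logits = [sum(w * xi for w, xi in zip(row, x)) for row in weights]
        m = max(logits)
        exps = [math.exp(z - m) for z in logits]
        denom = sum(exps)
        probs = [e / denom for e in exps]
        total_loss += -math.log(probs[y])
        count += 1
        # Backward: dL/dlogit_c = p_c - [c == y], and dlogit_c/dW_c = x.
        for c in range(num_classes):
            g = probs[c] - (1.0 if c == y else 0.0)
            for j, xj in enumerate(x):
                grads[c][j] += g * xj
    if count == 0:
        return {"loss": 0.0}
    # Update: SGD on the mean gradient, applied to weights in place.
    for c in range(num_classes):
        for j in range(len(weights[c])):
            weights[c][j] -= learning_rate * grads[c][j] / count
    # Like train_step, return a dictionary of logs (loss before the update).
    return {"loss": total_loss / count}


weights = [[0.0, 0.0], [0.0, 0.0]]
batch = {"features": [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], "label_ids": [0, 1, -1]}
first_loss = toy_train_step(batch, weights)["loss"]  # log(2): all-zero weights
for _ in range(20):
    last_loss = toy_train_step(batch, weights)["loss"]
```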
validation_step
validation_step(
inputs, model: tf.keras.Model, metrics=None
)
Validation step.
Args | |
---|---|
inputs | a dictionary of input tensors. |
model | the keras.Model. |
metrics | a nested structure of metrics objects. |
Returns | |
---|---|
A dictionary of logs. |
Class Variables | |
---|---|
loss | 'loss' |