tfr.extension.task.RankingTask

Task object for TF-Ranking.

Args
params the task configuration instance, which can be any of dataclass, ConfigDict, namedtuple, etc.
logging_dir a string pointing to where the model, summaries, etc. will be saved. You can also write additional files to this directory.
name the task name.

Attributes
logging_dir

name Returns the name of this module as passed or determined in the constructor.

name_scope Returns a tf.name_scope instance for this class.
non_trainable_variables Sequence of non-trainable variables owned by this module and its submodules.
submodules Sequence of all sub-modules.

Submodules are modules which are properties of this module, or found as properties of modules which are properties of this module (and so on).

>>> import tensorflow as tf
>>> a = tf.Module()
>>> b = tf.Module()
>>> c = tf.Module()
>>> a.b = b
>>> b.c = c
>>> list(a.submodules) == [b, c]
True
>>> list(b.submodules) == [c]
True
>>> list(c.submodules) == []
True

task_config

trainable_variables Sequence of trainable variables owned by this module and its submodules.

variables Sequence of variables owned by this module and its submodules.

Methods

aggregate_logs

Optional aggregation over logs returned from a validation step.

Given step_logs from a validation step, this function aggregates the logs after each eval_step() (see the eval_reduce() function in official/core/base_trainer.py). It runs on CPU and can be used to aggregate metrics during validation when there are too many metrics to fit into TPU memory. Note that this may increase latency due to data transfer between TPU and CPU. Also, the output of a validation step may be a tuple with one element per replica; in that case, the elements need to be concatenated.

Args
state The current state of training; for example, a sequence of metrics.
step_logs Logs from a validation step. Can be a dictionary.
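A minimal pure-Python sketch of this aggregation pattern follows. It assumes each validation step logs per-example values under a string key (the key "ndcg" and the list-of-floats layout are hypothetical), that multi-replica output arrives as a tuple of per-replica lists, and that the function returns the updated state so the trainer can pass it back in on the next step. It is not the TF-Ranking implementation, which operates on tensors.

```python
from typing import Dict, List, Optional, Tuple, Union

# Hypothetical log layout: per-example values, either a single list or a
# tuple holding one list per replica.
StepValue = Union[List[float], Tuple[List[float], ...]]


def aggregate_logs(state: Optional[Dict[str, List[float]]],
                   step_logs: Dict[str, StepValue]) -> Dict[str, List[float]]:
  """Appends one validation step's values to the running state (sketch)."""
  if state is None:
    state = {}
  for key, value in step_logs.items():
    if isinstance(value, tuple):
      # Multi-replica output: concatenate the per-replica lists.
      flat = [x for replica_values in value for x in replica_values]
    else:
      flat = list(value)
    state.setdefault(key, []).extend(flat)
  return state


# Usage: two replicas report on the first step, one on the second.
state = None
state = aggregate_logs(state, {"ndcg": ([0.75, 0.5], [1.0])})
state = aggregate_logs(state, {"ndcg": [0.25]})
print(state)  # {'ndcg': [0.75, 0.5, 1.0, 0.25]}
```

Keeping the raw per-example values on the host, rather than a running average, is what lets the final reduction compute metrics that do not decompose into per-step means.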

build_inputs

Returns a dataset or a nested structure of dataset functions.

Dataset functions define per-host datasets with the per-replica batch size. With distributed training, this method runs on remote hosts.

Args
params hyperparameters to create input pipelines, which can be any of dataclass, ConfigDict, namedtuple, etc.
input_context an optional distribution input pipeline context (tf.distribute.InputContext).

Returns
A nested structure of per-replica input functions.
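The plain-Python sketch below illustrates what "per-host dataset with the per-replica batch size" means. FakeInputContext is a hypothetical stand-in exposing the same fields (num_input_pipelines, input_pipeline_id, num_replicas_in_sync) and get_per_replica_batch_size() method that tf.distribute.InputContext provides; a real pipeline would typically shard and batch with tf.data.Dataset.shard() and tf.data.Dataset.batch(). Dropping the final partial batch is a choice made for this sketch, not something this page specifies.

```python
import dataclasses
from typing import Iterator, List, Optional


@dataclasses.dataclass
class FakeInputContext:
  """Hypothetical stand-in for a few tf.distribute.InputContext fields."""
  num_input_pipelines: int = 1
  input_pipeline_id: int = 0
  num_replicas_in_sync: int = 1

  def get_per_replica_batch_size(self, global_batch_size: int) -> int:
    if global_batch_size % self.num_replicas_in_sync:
      raise ValueError("global batch size must divide evenly across replicas")
    return global_batch_size // self.num_replicas_in_sync


def build_inputs(examples: List[int], global_batch_size: int,
                 input_context: Optional[FakeInputContext] = None
                 ) -> Iterator[List[int]]:
  """Yields per-replica batches from this host's shard of the data."""
  if input_context is None:
    input_context = FakeInputContext()
  # Each host (input pipeline) reads a disjoint, strided shard.
  shard = examples[input_context.input_pipeline_id::
                   input_context.num_input_pipelines]
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  # Drop the trailing partial batch so every replica sees the same shape.
  for start in range(0, len(shard) - batch_size + 1, batch_size):
    yield shard[start:start + batch_size]


# Usage: host 1 of 2, with 2 replicas sharing a global batch of 4.
ctx = FakeInputContext(num_input_pipelines=2, input_pipeline_id=1,
                       num_replicas_in_sync=2)
batches = list(build_inputs(list(range(10)), global_batch_size=4,
                            input_context=ctx))
print(batches)  # [[1, 3], [5, 7]]
```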

build_losses

Standard interface to compute losses.

Args
labels optional label tensors.
model_outputs a nested structure of output tensors.
aux_losses auxiliary loss tensors, i.e. the losses collected in keras.Model.losses.

Returns
The total loss tensor.
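As a rough illustration of "total loss", this pure-Python sketch adds the auxiliary losses to a main ranking loss. The main loss here is a listwise softmax cross-entropy, chosen as a stand-in for whatever ranking loss the task is actually configured with; normalizing the relevance labels into a target distribution is an assumption of the sketch, and the function names are hypothetical.

```python
import math
from typing import Optional, Sequence


def softmax_list_loss(labels: Sequence[float],
                      scores: Sequence[float]) -> float:
  """Listwise softmax cross-entropy for one query (illustrative sketch)."""
  # Turn relevance labels into a target distribution over the list.
  total = sum(labels)
  targets = [y / total for y in labels] if total > 0 else [0.0] * len(labels)
  # Numerically stable log-partition of the scores.
  m = max(scores)
  log_z = m + math.log(sum(math.exp(s - m) for s in scores))
  return -sum(t * (s - log_z) for t, s in zip(targets, scores))


def build_losses(labels: Sequence[float], model_outputs: Sequence[float],
                 aux_losses: Optional[Sequence[float]] = None) -> float:
  """Total loss = main ranking loss + sum of auxiliary losses."""
  loss = softmax_list_loss(labels, model_outputs)
  if aux_losses:
    loss += sum(aux_losses)
  return loss


# Usage: two equally scored items, only the first relevant, plus one aux loss.
total = build_losses([1.0, 0.0], [0.0, 0.0], aux_losses=[0.5])
print(round(total, 4))  # 1.1931  (= log 2 + 0.5)
```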

build_metrics

Gets streaming metrics for training/validation.
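"Streaming" means a metric accumulates state across batches and reports a value only when asked, until it is reset. The pure-Python sketch below mirrors the update_state()/result()/reset_state() methods of Keras metrics; the StreamingMean class, the build_metrics(training=...) signature, and the metric names are hypothetical illustrations, not the metrics RankingTask actually returns.

```python
from typing import List, Sequence


class StreamingMean:
  """Minimal streaming metric: a running mean kept until reset."""

  def __init__(self, name: str):
    self.name = name
    self.reset_state()

  def update_state(self, values: Sequence[float]) -> None:
    self.total += sum(values)
    self.count += len(values)

  def result(self) -> float:
    return self.total / self.count if self.count else 0.0

  def reset_state(self) -> None:
    self.total = 0.0
    self.count = 0


def build_metrics(training: bool = True) -> List[StreamingMean]:
  """Returns fresh metric objects; validation adds a hypothetical 'ndcg'."""
  metrics = [StreamingMean("loss")]
  if not training:
    metrics.append(StreamingMean("ndcg"))
  return metrics


# Usage: the mean covers both batches, not just the last one.
loss_metric = StreamingMean("loss")
loss_metric.update_state([1.0, 2.0])
loss_metric.update_state([3.0])
print(loss_metric.result())  # 2.0
loss_metric.reset_state()
print(loss_metric.result())  # 0.0
```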

build_model

[Optional] Creates the model architecture.

Returns
A model instance.

create_optimizer

Creates a TF optimizer from the given configurations.

Args
optimizer_config the optimization settings.
runtime_config the runtime settings.
dp_config the differential privacy settings.

Returns
A tf.optimizers.Optimizer object.

inference_step

Performs the forward step.

With distribution strategies, this method runs on devices.

Args
inputs a dictionary of input tensors.
model the keras.Model.

Returns
Model outputs.

initialize

[Optional] A callback function used as CheckpointManager's init_fn.

This function will be called when no checkpoint is found for the model. If there is a checkpoint, the checkpoint will be loaded and this function will not be called. You can use this callback function to load a pretrained checkpoint, saved under a directory other than the model_dir.

Args
model The keras.Model built or used by this task.
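tf.train.CheckpointManager accepts an init_fn and exposes restore_or_initialize(), which restores the latest checkpoint when one exists and otherwise calls init_fn. The pure-Python sketch below mimics only that control flow. The "ckpt-N" file naming, the restore callback, and the returned status strings are hypothetical simplifications, not TensorFlow's on-disk checkpoint format or return values.

```python
import os
import tempfile
from typing import Callable, Optional


def latest_checkpoint(model_dir: str) -> Optional[str]:
  """Returns the highest-numbered 'ckpt-N' file in model_dir, or None."""
  ckpts = sorted(
      (f for f in os.listdir(model_dir) if f.startswith("ckpt-")),
      key=lambda f: int(f.split("-")[1]))  # Numeric, so ckpt-10 > ckpt-3.
  return os.path.join(model_dir, ckpts[-1]) if ckpts else None


def restore_or_initialize(model_dir: str, restore: Callable[[str], None],
                          init_fn: Callable[[], None]) -> str:
  ckpt = latest_checkpoint(model_dir)
  if ckpt is not None:
    restore(ckpt)  # An existing checkpoint wins; init_fn is never called.
    return "restored"
  init_fn()  # e.g. load pretrained weights saved outside model_dir.
  return "initialized"


# Usage: first run has no checkpoint; a later run finds two.
with tempfile.TemporaryDirectory() as model_dir:
  calls = []
  status = restore_or_initialize(model_dir, restore=calls.append,
                                 init_fn=lambda: calls.append("init"))
  print(status, calls)  # initialized ['init']
  for name in ("ckpt-3", "ckpt-10"):
    open(os.path.join(model_dir, name), "w").close()
  calls.clear()
  status = restore_or_initialize(model_dir, restore=calls.append,
                                 init_fn=lambda: calls.append("init"))
  print(status, os.path.basename(calls[0]))  # restored ckpt-10
```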

process_compiled_metrics

Processes and updates compiled_metrics.

Called when using the compile/fit API.

Args
compiled_metrics the compiled metrics (model.compiled_metrics).
labels a tensor or a nested structure of tensors.
model_outputs a tensor or a nested structure of tensors; for example, the output of the Keras model built by self.build_model.

process_metrics

Processes and updates metrics.

Called when using a custom training loop.

Args
metrics a nested structure of metrics objects, as returned by self.build_metrics.
labels a tensor or a nested structure of tensors.
model_outputs a tensor or a nested structure of tensors; for example, the output of the Keras model built by self.build_model.
**kwargs other arguments.
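A sketch of process_metrics for the simple case where metrics is a flat list; a nested structure would first be flattened (for example with tf.nest.flatten). MeanAbsoluteError here is a hypothetical pure-Python metric with a Keras-style update_state(labels, predictions) call, used only to make the example runnable.

```python
from typing import List, Sequence


class MeanAbsoluteError:
  """Tiny streaming metric with a Keras-like update_state(labels, outputs)."""

  def __init__(self, name: str = "mae"):
    self.name = name
    self.total, self.count = 0.0, 0

  def update_state(self, labels: Sequence[float],
                   outputs: Sequence[float]) -> None:
    self.total += sum(abs(y - p) for y, p in zip(labels, outputs))
    self.count += len(labels)

  def result(self) -> float:
    return self.total / self.count if self.count else 0.0


def process_metrics(metrics: List[MeanAbsoluteError],
                    labels: Sequence[float],
                    model_outputs: Sequence[float], **kwargs) -> None:
  """Feeds one batch of labels and outputs to every metric (sketch)."""
  del kwargs  # Unused in this sketch.
  for metric in metrics:
    metric.update_state(labels, model_outputs)


# Usage: two batches update the same metric objects in place.
metrics = [MeanAbsoluteError()]
process_metrics(metrics, [1.0, 0.0], [0.75, 0.25])
process_metrics(metrics, [1.0, 1.0], [0.5, 1.0])
print(metrics[0].result())  # 0.25
```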

reduce_aggregated_logs

Optional reduction of aggregated logs over validation steps.

This function reduces aggregated logs at the end of validation and can be used to compute the final metrics. It runs on CPU, in each eval_end() call of the base trainer (see the eval_end() function in official/core/base_trainer.py).

Args
aggregated_logs Aggregated logs over multiple validation steps.
global_step An optional global step variable.

Returns
A dictionary of reduced results.
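A pure-Python sketch of a reduction step, assuming the aggregated logs map a hypothetical metric name to a flat list of per-example values and that the final metric is simply their mean. A real ranking metric may instead weight by query or list size; this is an illustration of the shape of the computation, not TF-Ranking's code.

```python
from typing import Dict, List, Optional


def reduce_aggregated_logs(aggregated_logs: Dict[str, List[float]],
                           global_step: Optional[int] = None
                           ) -> Dict[str, float]:
  """Averages per-example values collected over all validation steps."""
  del global_step  # Not needed by this sketch.
  return {key: (sum(values) / len(values) if values else 0.0)
          for key, values in aggregated_logs.items()}


# Usage: four per-example values gathered across several validation steps.
results = reduce_aggregated_logs({"ndcg": [0.75, 0.5, 1.0, 0.25]})
print(results)  # {'ndcg': 0.625}
```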

train_step

Performs the forward and backward passes.

With distribution strategies, this method runs on devices.

Args
inputs a dictionary of input tensors.
model the model, which defines the forward pass.
optimizer the optimizer for this training step.
metrics a nested structure of metrics objects.

Returns
A dictionary of logs.
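To make "forward and backward" concrete, here is a pure-Python sketch of one training step for a hypothetical linear scorer with mean-squared-error loss and plain gradient descent. In the real task the forward pass is the Keras model, gradients come from tf.GradientTape, and optimizer.apply_gradients() performs the update; none of that is reproduced here. Keying the returned dictionary by 'loss' is an assumption, chosen to match the loss class variable.

```python
from typing import Dict, List, Sequence


def train_step(inputs: Dict[str, Sequence], weights: List[float],
               learning_rate: float = 0.1) -> Dict[str, float]:
  """One forward/backward pass; updates `weights` in place (sketch)."""
  features, labels = inputs["features"], inputs["labels"]
  n = len(labels)
  # Forward pass: score_i = w . x_i, then mean squared error.
  scores = [sum(w * x for w, x in zip(weights, row)) for row in features]
  loss = sum((s - y) ** 2 for s, y in zip(scores, labels)) / n
  # Backward pass: d(loss)/d(w_j) = 2/n * sum_i (score_i - y_i) * x_ij.
  grads = [2.0 / n * sum((s - y) * row[j]
                         for s, y, row in zip(scores, labels, features))
           for j in range(len(weights))]
  # Gradient-descent update, standing in for optimizer.apply_gradients().
  for j, g in enumerate(grads):
    weights[j] -= learning_rate * g
  return {"loss": loss}


# Usage: two steps on a single example; the loss should go down.
weights = [0.0]
batch = {"features": [[1.0]], "labels": [1.0]}
first = train_step(batch, weights)
second = train_step(batch, weights)
print(first["loss"], second["loss"] < first["loss"])  # 1.0 True
```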

validation_step

Validation step.

With distribution strategies, this method runs on devices.

Args
inputs a dictionary of input tensors.
model the keras.Model.
metrics a nested structure of metrics objects.

Returns
A dictionary of logs.

with_name_scope

Decorator to automatically enter the module name scope.

import tensorflow as tf

class MyModule(tf.Module):
  @tf.Module.with_name_scope
  def __call__(self, x):
    if not hasattr(self, 'w'):
      self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
    return tf.matmul(x, self.w)

Using the above module produces tf.Variables and tf.Tensors whose names include the module name:

>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=..., dtype=float32)>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32,
numpy=..., dtype=float32)>

Args
method The method to wrap.

Returns
The original method wrapped such that it enters the module's name scope.

Class Variables
loss 'loss'