Defines an "action" abstraction for use with orbit.Controller.
"Actions" are simply arbitrary callables that the Controller applies to the output of train steps (after each inner loop of steps_per_loop steps) or of an evaluation. This provides a hook mechanism for things like reporting metrics to Vizier, exporting models, and additional logging.
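An action is just a callable that receives the output, so a plain Python function qualifies. The sketch below is self-contained and does not import Orbit. It assumes outputs are dictionaries of metric values. The Output and Action aliases and the run_with_actions loop are hypothetical stand-ins that only mimic how a controller passes each output to its actions.

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins: an output is a dict of metric values, and an action is
# any callable that accepts an output (its return value is ignored).
Output = Dict[str, float]
Action = Callable[[Output], None]

logged: List[Output] = []


def log_metrics(output: Output) -> None:
    """A trivial action that records a copy of each output it receives."""
    logged.append(dict(output))


def run_with_actions(outputs: List[Output], actions: List[Action]) -> None:
    """Toy loop (not Orbit's Controller) that applies every action to each output."""
    for output in outputs:
        for action in actions:
            action(output)


run_with_actions([{'accuracy': 0.8}, {'accuracy': 0.85}], [log_metrics])
print(logged)  # [{'accuracy': 0.8}, {'accuracy': 0.85}]
```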
The basic Action abstraction (just a type alias) is defined in the controller module. This actions module adds a ConditionalAction utility class that makes it easy to trigger actions conditionally based on reusable predicates. It also provides a small handful of predefined conditions and actions, in particular a NewBestMetric condition and an ExportSavedModel action.
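Here is a simplified, hypothetical re-implementation that illustrates the idea behind ConditionalAction. SimpleConditionalAction is not part of Orbit, and Orbit's actual class may differ in its details. It wraps one or more actions and runs them only when a predicate on the output returns true.

```python
from typing import Any, Callable, List, Sequence, Union

Output = Any
Action = Callable[[Output], None]
Condition = Callable[[Output], bool]


class SimpleConditionalAction:
    """Hypothetical, simplified conditional action (not Orbit's implementation).

    Runs the wrapped action(s) only when `condition(output)` is truthy.
    """

    def __init__(self, condition: Condition,
                 action: Union[Action, Sequence[Action]]):
        self.condition = condition
        # Accept either a single callable or a list/tuple of callables.
        if isinstance(action, (list, tuple)):
            self.actions: List[Action] = list(action)
        else:
            self.actions = [action]

    def __call__(self, output: Output) -> None:
        if self.condition(output):
            for action in self.actions:
                action(output)


fired = []
high_accuracy_only = SimpleConditionalAction(
    condition=lambda out: out['accuracy'] > 0.9,
    action=fired.append)

for out in [{'accuracy': 0.5}, {'accuracy': 0.95}]:
    high_accuracy_only(out)
print(fired)  # [{'accuracy': 0.95}]
```

Conditions are plain predicates, so they can be combined with ordinary boolean operators. The export example in this module does exactly that with lambda x: x['accuracy'] > 0.9 and new_best_metric(x).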
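One example of using actions for metric-conditional export. Here, a SavedModel is exported after an evaluation only when accuracy exceeds 0.9 and is also a new best: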
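```python
import orbit

# model, trainer, evaluator, strategy, checkpoint_manager, and FLAGS are
# assumed to be defined by the surrounding training program.
new_best_metric = orbit.actions.NewBestMetric('accuracy')

# Export only when accuracy exceeds 0.9 and is the best value seen so far.
export_action = orbit.actions.ConditionalAction(
    condition=lambda x: x['accuracy'] > 0.9 and new_best_metric(x),
    action=orbit.actions.ExportSavedModel(
        model,
        orbit.actions.ExportFileManager(
            base_name=f'{FLAGS.model_dir}/saved_model',
            # Supplies the id for each export (here, the current global step).
            next_id_fn=trainer.global_step.numpy),
        signatures=model.infer))

# The controller applies export_action to the output of each evaluation.
controller = orbit.Controller(
    strategy=strategy,
    trainer=trainer,
    evaluator=evaluator,
    eval_actions=[export_action],
    global_step=trainer.global_step,
    steps_per_loop=FLAGS.steps_per_loop,
    checkpoint_manager=checkpoint_manager,
    summary_interval=1000)
```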
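In multi-client settings, each client runs its own Controller, so it can be useful to restrict some actions to a single client. ConditionalAction handles this as well. In the following snippet, client_id() stands for whatever mechanism the client code uses to identify the current client:

```python
client_0_actions = orbit.actions.ConditionalAction(
    # The condition ignores the output and only checks which client is running.
    condition=lambda _: client_id() == 0,
    action=[
        ...
    ])
```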
In particular, the NewBestMetric condition may be used in multi-client settings if all clients are guaranteed to compute the same metric (ensuring this is up to client code, not Orbit). However, when saving metrics, it may be helpful to avoid unnecessary writes by setting the write_value parameter to False for most clients.
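The write-avoidance idea can be sketched without Orbit. PersistedBestMetric is a hypothetical, simplified stand-in for a NewBestMetric condition backed by a JSON file, which is roughly the role JSONPersistedValue plays. Its parameter names mirror the prose above but are not necessarily Orbit's exact signature. Every client tracks the best value in memory, but only a client with write_value=True updates the file. A restarted client recovers the best value from that file.

```python
import json
import os
import tempfile
from typing import Dict, Optional


class PersistedBestMetric:
    """Hypothetical, simplified "new best metric" condition (not Orbit's code).

    Tracks the best value of `metric` seen so far. If `filename` is set, the
    initial best value is read from that JSON file, and new best values are
    written back only when `write_value` is True.
    """

    def __init__(self, metric: str, higher_is_better: bool = True,
                 filename: Optional[str] = None, write_value: bool = True):
        self.metric = metric
        self.higher_is_better = higher_is_better
        self.filename = filename
        self.write_value = write_value
        self.best: Optional[float] = None
        if filename is not None and os.path.exists(filename):
            with open(filename) as f:
                self.best = json.load(f)

    def __call__(self, output: Dict[str, float]) -> bool:
        value = output[self.metric]
        # A new best must be strictly better than the previous best.
        if self.best is None:
            is_new_best = True
        elif self.higher_is_better:
            is_new_best = value > self.best
        else:
            is_new_best = value < self.best
        if is_new_best:
            self.best = value  # every client updates its in-memory copy
            if self.filename is not None and self.write_value:
                with open(self.filename, 'w') as f:
                    json.dump(value, f)
        return is_new_best


path = os.path.join(tempfile.mkdtemp(), 'best_accuracy.json')
writer = PersistedBestMetric('accuracy', filename=path)  # e.g. client 0
reader = PersistedBestMetric('accuracy', filename=path, write_value=False)

outputs = [{'accuracy': 0.8}, {'accuracy': 0.7}, {'accuracy': 0.9}]
results = [(writer(o), reader(o)) for o in outputs]
print(results)  # [(True, True), (False, False), (True, True)]

# A restarted client recovers the persisted best value from the file.
print(PersistedBestMetric('accuracy', filename=path).best)  # 0.9
```

This only works as intended when, as the paragraph above requires, all clients compute the same metric values. Otherwise the reading clients' in-memory best values would drift from the file.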
Classes
class ConditionalAction: Represents an action that is only taken when a given condition is met.
class ExportFileManager: Utility class that manages a group of files with a shared base name.
class ExportSavedModel: Action that exports the given model as a SavedModel.
class JSONPersistedValue: Represents a value that is persisted via a file-based backing store.
class NewBestMetric: Condition that is satisfied when a new best metric is achieved.
class SaveCheckpointIfPreempted: Action that saves on-demand checkpoints after a preemption.
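A small hypothetical sketch can illustrate ExportFileManager's role of managing a group of files with a shared base name. SimpleExportFileManager is an assumption for illustration and does not reflect Orbit's documented behavior. The same applies to its f'{base_name}-{id}' naming scheme, its max_to_keep pruning, and the default of 5. As in the export example that passes trainer.global_step.numpy, the id comes from a next_id_fn callable.

```python
import glob
import os
import shutil
import tempfile
from typing import Callable, List


class SimpleExportFileManager:
    """Hypothetical manager for exports sharing a base name (not Orbit's code).

    Exports are named f'{base_name}-{id}', with `id` taken from `next_id_fn`,
    and only the `max_to_keep` exports with the highest ids are kept.
    """

    def __init__(self, base_name: str, next_id_fn: Callable[[], int],
                 max_to_keep: int = 5):
        self.base_name = base_name
        self.next_id_fn = next_id_fn
        self.max_to_keep = max_to_keep

    @property
    def managed_files(self) -> List[str]:
        """Existing exports with this base name, lowest id first."""
        prefix = self.base_name + '-'
        paths = [p for p in glob.glob(glob.escape(prefix) + '*')
                 if p[len(prefix):].isdigit()]
        return sorted(paths, key=lambda p: int(p[len(prefix):]))

    def next_name(self) -> str:
        """Path for the next export, built from the current id."""
        return f'{self.base_name}-{self.next_id_fn()}'

    def clean_up(self) -> None:
        """Deletes the oldest exports beyond `max_to_keep`."""
        files = self.managed_files
        for path in files[:max(len(files) - self.max_to_keep, 0)]:
            if os.path.isdir(path):
                shutil.rmtree(path)
            else:
                os.remove(path)


state = {'step': 0}  # stand-in for a global step counter
manager = SimpleExportFileManager(
    os.path.join(tempfile.mkdtemp(), 'saved_model'),
    next_id_fn=lambda: state['step'], max_to_keep=2)

for step in (100, 200, 300):
    state['step'] = step
    os.makedirs(manager.next_name())  # stand-in for writing a SavedModel
    manager.clean_up()

print([os.path.basename(p) for p in manager.managed_files])
# ['saved_model-200', 'saved_model-300']
```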