tfm.core.train_lib.OrbitExperimentRunner
Runs an experiment with the Orbit training loop.
tfm.core.train_lib.OrbitExperimentRunner(
    distribution_strategy: tf.distribute.Strategy,
    task: tfm.core.base_task.Task,
    mode: str,
    params: tfm.core.base_trainer.ExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    save_summary: bool = True,
    train_actions: Optional[List[orbit.Action]] = None,
    eval_actions: Optional[List[orbit.Action]] = None,
    trainer: Optional[tfm.core.base_trainer.Trainer] = None,
    controller_cls=orbit.Controller,
    summary_manager: Optional[orbit.utils.SummaryManager] = None,
    eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
    enable_async_checkpointing: bool = False
)
The default experiment runner for Model Garden experiments. Users can
customize the experiment pipeline by subclassing this class and replacing
components or functions.
For example, an experiment runner with a customized checkpoint manager:
class MyExpRunnerWithExporter(OrbitExperimentRunner):
  def _maybe_build_checkpoint_manager(self):
    # Replaces the default CheckpointManager with a customized one.
    # MyCheckpointManager and args are user-defined placeholders.
    return MyCheckpointManager(*args)

# In user code, instead of the original
# `OrbitExperimentRunner(...).run()`, users can now call the subclass.
# `needed_kwargs` holds the constructor arguments, including `mode`.
MyExpRunnerWithExporter(**needed_kwargs).run()
Similar overrides can be applied to other components.
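The override mechanism described above is the template-method pattern: the base runner calls a builder hook while setting itself up, so a subclass that redefines the hook changes what gets built. The sketch below shows that pattern in plain Python with no TensorFlow dependency. `BaseRunner`, `DefaultCheckpointManager`, `MyCheckpointManager`, and `MyRunner` are hypothetical stand-ins, not Model Garden classes, and calling the hook from `__init__` is an assumption made for illustration.

```python
class DefaultCheckpointManager:
    """Hypothetical stand-in for the default checkpoint manager."""

    def __init__(self, directory: str):
        self.directory = directory


class MyCheckpointManager(DefaultCheckpointManager):
    """Hypothetical user-defined manager with a custom retention limit."""

    def __init__(self, directory: str, max_to_keep: int = 3):
        super().__init__(directory)
        self.max_to_keep = max_to_keep


class BaseRunner:
    """Hypothetical runner that builds its components through overridable hooks."""

    def __init__(self, model_dir: str):
        self.model_dir = model_dir
        # The base class calls the hook; subclasses decide what it returns.
        self.checkpoint_manager = self._maybe_build_checkpoint_manager()

    def _maybe_build_checkpoint_manager(self):
        return DefaultCheckpointManager(self.model_dir)


class MyRunner(BaseRunner):
    def _maybe_build_checkpoint_manager(self):
        # Replaces the default manager with a customized one.
        return MyCheckpointManager(self.model_dir, max_to_keep=5)


default_runner = BaseRunner("/tmp/exp")
custom_runner = MyRunner("/tmp/exp")
print(type(default_runner.checkpoint_manager).__name__)  # DefaultCheckpointManager
print(type(custom_runner.checkpoint_manager).__name__)   # MyCheckpointManager
```

Because only the hook is overridden, the rest of the setup logic in the base class is reused unchanged.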
Args
distribution_strategy: A distribution strategy.
task: A Task instance.
mode: A str specifying the mode. Can be 'train', 'eval', 'train_and_eval', or 'continuous_eval'.
params: An ExperimentConfig instance.
model_dir: A str; the path where model checkpoints and summaries are stored.
run_post_eval: Whether to run a post-training evaluation once after training. If True, its metrics logs are returned by run().
save_summary: Whether to save training and validation summaries.
train_actions: Optional list of Orbit train actions.
eval_actions: Optional list of Orbit eval actions.
trainer: The base_trainer.Trainer instance. It should be created within strategy.scope().
controller_cls: The controller class that manages the train and eval process. Must be an orbit.Controller subclass.
summary_manager: Instance of the summary manager that overrides the default summary manager.
eval_summary_manager: Instance of the eval summary manager that overrides the default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable asynchronous checkpoint saving.
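Two of the arguments above carry constraints a caller can check up front: `mode` must be one of the four listed strings, and `controller_cls` must subclass `orbit.Controller`. The sketch below checks both in plain Python. `validate_runner_args` and the `Controller` placeholder are hypothetical; this does not claim that the real constructor validates its arguments this way or at this point.

```python
from typing import Type

# Modes listed in the `mode` argument description.
SUPPORTED_MODES = frozenset({"train", "eval", "train_and_eval", "continuous_eval"})


class Controller:
    """Hypothetical placeholder standing in for orbit.Controller."""


def validate_runner_args(mode: str, controller_cls: Type) -> None:
    """Raises ValueError for an unknown mode, TypeError for a bad controller class."""
    if mode not in SUPPORTED_MODES:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of {sorted(SUPPORTED_MODES)}")
    # Guard with isinstance first so non-class values do not crash issubclass.
    if not (isinstance(controller_cls, type) and issubclass(controller_cls, Controller)):
        raise TypeError("controller_cls must be a subclass of Controller")


class MyController(Controller):
    pass


validate_runner_args("train_and_eval", MyController)  # Passes silently.
```

Failing fast on these two arguments surfaces configuration typos before any expensive setup such as building the model or restoring checkpoints.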
Attributes
checkpoint_manager: The CheckpointManager that stores the checkpoints in a train job.
controller: The Orbit controller object.
model_dir: Path to the model folder, which stores checkpoints, params, log, etc.
params: The whole experiment parameters object.
trainer: The underlying Orbit Trainer object.
Methods
run
run() -> Tuple[tf.keras.Model, Mapping[str, Any]]
Runs the experiment according to its mode.
Returns
A 2-tuple of (model, eval_logs):
model: A tf.keras.Model instance.
eval_logs: The eval metrics logs when run_post_eval is set to True; otherwise, {}.
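The return contract of `run()` can be modeled as: dispatch on the mode, then return the model together with evaluation logs only when `run_post_eval` is set. The sketch below mimics that contract in plain Python. `FakeController`, `run_sketch`, and the fake's method names are hypothetical stand-ins, and the metric value is made up; the real runner delegates to an Orbit controller and returns a `tf.keras.Model`.

```python
from typing import Any, Dict, Mapping, Tuple


class FakeController:
    """Hypothetical stand-in that records which loop was requested."""

    def __init__(self):
        self.calls = []

    def train(self):
        self.calls.append("train")

    def evaluate(self) -> Dict[str, float]:
        self.calls.append("evaluate")
        return {"accuracy": 0.9}  # Made-up metric for illustration only.

    def train_and_evaluate(self):
        self.calls.append("train_and_evaluate")

    def evaluate_continuously(self):
        self.calls.append("evaluate_continuously")


def run_sketch(mode: str, controller: FakeController, model: Any,
               run_post_eval: bool = False) -> Tuple[Any, Mapping[str, Any]]:
    """Dispatches on `mode`, then returns (model, eval_logs)."""
    dispatch = {
        "train": controller.train,
        "eval": controller.evaluate,
        "train_and_eval": controller.train_and_evaluate,
        "continuous_eval": controller.evaluate_continuously,
    }
    if mode not in dispatch:
        raise NotImplementedError(f"The mode is not implemented: {mode}")
    dispatch[mode]()
    # eval_logs is non-empty only when a post-training evaluation is requested.
    eval_logs = controller.evaluate() if run_post_eval else {}
    return model, eval_logs


model, logs = run_sketch("train", FakeController(), model="my-model")
print(logs)  # {}
```

Returning an empty mapping rather than None keeps the return type uniform, so callers can always iterate over eval_logs.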
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-02-02 UTC.