tf_agents.utils.eager_utils.create_train_step
Creates a train_step that evaluates the gradients and returns the loss.
tf_agents.utils.eager_utils.create_train_step(
    loss,
    optimizer,
    global_step=_USE_GLOBAL_STEP,
    total_loss_fn=None,
    update_ops=None,
    variables_to_train=None,
    transform_grads_fn=None,
    summarize_gradients=False,
    gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
    aggregation_method=None,
    check_numerics=True
)
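
For example, a minimal eager-mode sketch (the variable w, the quadratic loss_fn, and the learning rate below are hypothetical, chosen only for illustration):

    import tensorflow as tf
    from tf_agents.utils import eager_utils

    # Hypothetical toy problem: fit a single weight toward a target value.
    w = tf.Variable(2.0)

    def loss_fn():
        # In eager mode, loss must be a callable returning the loss tensor.
        return tf.square(w - 5.0)

    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=0.1)

    # In eager mode, create_train_step returns a function; each call computes
    # the loss, applies the gradients, and returns the loss value.
    train_step = eager_utils.create_train_step(loss_fn, optimizer)

    for _ in range(100):
        loss_value = train_step()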
Args:
  loss: A (possibly nested tuple of) Tensor or a function representing the loss.
  optimizer: A tf.Optimizer to use for computing the gradients.
  global_step: A Tensor representing the global step variable. If left as _USE_GLOBAL_STEP, tf.train.get_or_create_global_step() is used.
  total_loss_fn: Function to call on the loss value to obtain the final scalar to minimize.
  update_ops: An optional list of update ops to execute. If update_ops is None, the update ops are set to the contents of the tf.GraphKeys.UPDATE_OPS collection. If update_ops is not None but does not contain all of the update ops in tf.GraphKeys.UPDATE_OPS, a warning is displayed.
  variables_to_train: An optional list of variables to train. If None, defaults to all tf.trainable_variables().
  transform_grads_fn: A function that takes a single argument, a list of (gradient, variable) pairs, performs any requested gradient updates such as clipping or scaling, and returns the updated list (see the clipping sketch after this list).
  summarize_gradients: Whether or not to add summaries for each gradient.
  gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
  aggregation_method: Specifies the method used to combine gradient terms. Valid values are defined in the class AggregationMethod.
  check_numerics: Whether or not to apply check_numerics.
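
As referenced in the transform_grads_fn entry above, a gradient transform is any function from a list of (gradient, variable) pairs to an updated list of pairs. A minimal global-norm clipping sketch, reusing loss_fn and optimizer from the earlier example (the 1.0 threshold is an arbitrary, hypothetical choice):

    def clip_grads_fn(grads_and_vars):
        # Unzip the pairs, clip all gradients jointly by global norm,
        # then re-pair them with their variables.
        grads, variables = zip(*grads_and_vars)
        clipped_grads, _ = tf.clip_by_global_norm(grads, 1.0)
        return list(zip(clipped_grads, variables))

    train_step = eager_utils.create_train_step(
        loss_fn, optimizer, transform_grads_fn=clip_grads_fn)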
Returns:
  In graph mode: a (possibly nested tuple of) Tensor that, when evaluated, calculates the current loss, computes the gradients, applies the optimizer, and returns the current loss.
  In eager mode: a function that, when called, calculates the loss, then computes and applies the gradients, and returns the original loss values.
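
To illustrate the graph-mode case, a sketch in TF1-style graph mode (the model and names are hypothetical; assumes eager execution has been disabled):

    import tensorflow as tf
    from tf_agents.utils import eager_utils

    tf.compat.v1.disable_eager_execution()

    # In graph mode, loss may be a plain Tensor rather than a callable.
    w = tf.compat.v1.get_variable('w', initializer=2.0)
    loss = tf.square(w - 5.0)
    optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.1)

    # The result is a Tensor: each sess.run applies the gradients and
    # yields the current loss value.
    train_op = eager_utils.create_train_step(loss, optimizer)

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        for _ in range(100):
            loss_value = sess.run(train_op)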
Raises:
  ValueError: If loss is not callable.
  RuntimeError: If resource variables are not enabled.