tf_agents.policies.ActorPolicy
Class to build Actor Policies.
Inherits From: TFPolicy
tf_agents.policies.ActorPolicy(
    time_step_spec: tf_agents.trajectories.TimeStep,
    action_spec: tf_agents.typing.types.NestedTensorSpec,
    actor_network: tf_agents.networks.Network,
    policy_state_spec: tf_agents.typing.types.NestedTensorSpec = (),
    info_spec: tf_agents.typing.types.NestedTensorSpec = (),
    observation_normalizer: Optional[tf_agents.utils.tensor_normalizer.TensorNormalizer] = None,
    clip: bool = True,
    training: bool = False,
    observation_and_action_constraint_splitter: Optional[types.Splitter] = None,
    name: Optional[Text] = None
)
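For orientation, here is a minimal sketch (not part of the original reference) that builds an ActorPolicy for a small continuous-control setup. The observation shape, action bounds, and layer sizes are illustrative assumptions:

    import tensorflow as tf
    from tf_agents.networks import actor_distribution_network
    from tf_agents.policies import actor_policy
    from tf_agents.specs import tensor_spec
    from tf_agents.trajectories import time_step as ts

    # Illustrative specs: a 4-dim float observation, a 1-dim action in [-1, 1].
    observation_spec = tf.TensorSpec(shape=(4,), dtype=tf.float32)
    action_spec = tensor_spec.BoundedTensorSpec(
        shape=(1,), dtype=tf.float32, minimum=-1.0, maximum=1.0)
    time_step_spec = ts.time_step_spec(observation_spec)

    # An actor network that emits a distribution over actions.
    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_spec, action_spec, fc_layer_params=(64, 64))

    policy = actor_policy.ActorPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        actor_network=actor_net,
        clip=True)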
Args

time_step_spec
    A TimeStep spec of the expected time_steps.

action_spec
    A nest of BoundedTensorSpec representing the actions.

actor_network
    An instance of a tf_agents.networks.network.Network to be used by the
    policy. The network will be called with call(observation, step_type,
    policy_state) and should return (actions_or_distributions, new_state).

policy_state_spec
    A nest of TensorSpec representing the policy_state. If not set, defaults
    to actor_network.state_spec.

info_spec
    A nest of TensorSpec representing the policy info.

observation_normalizer
    An object to use for observation normalization.

clip
    Whether to clip actions to spec before returning them. Default True.
    Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
    continuous actions for training.

training
    Whether the network should be called in training mode.

observation_and_action_constraint_splitter
    A function used to process observations with action constraints. These
    constraints can indicate, for example, a mask of valid/invalid actions
    for a given state of the environment. The function takes in a full
    observation and returns a tuple consisting of 1) the part of the
    observation intended as input to the network and 2) the constraint. An
    example observation_and_action_constraint_splitter could be as simple as:

        def observation_and_action_constraint_splitter(observation):
          return observation['network_input'], observation['constraint']

    Note: when using observation_and_action_constraint_splitter, make sure
    the provided actor_network is compatible with the network-specific half
    of the output of the observation_and_action_constraint_splitter. In
    particular, observation_and_action_constraint_splitter will be called on
    the observation before passing to the network. If
    observation_and_action_constraint_splitter is None, action constraints
    are not applied.

name
    The name of this policy. All variables in this module will fall under
    that name. Defaults to the class name.
Raises

ValueError
    If actor_network is not of type network.Network.

NotImplementedError
    If observation_and_action_constraint_splitter is not None but
    action_spec is not discrete.
Attributes

action_spec
    Describes the TensorSpecs of the Tensors expected by step(action).
    action can be a single Tensor, or a nested dict, list or tuple of
    Tensors.

collect_data_spec
    Describes the Tensors written when using this policy with an environment.

emit_log_probability
    Whether this policy instance emits log probabilities or not.

info_spec
    Describes the Tensors emitted as info by action and distribution. info
    can be an empty tuple, a single Tensor, or a nested dict, list or tuple
    of Tensors.

observation_and_action_constraint_splitter

observation_normalizer

policy_state_spec
    Describes the Tensors expected by step(_, policy_state). policy_state
    can be an empty tuple, a single Tensor, or a nested dict, list or tuple
    of Tensors.

policy_step_spec
    Describes the output of action().

time_step_spec
    Describes the TimeStep tensors returned by step().

trajectory_spec
    Describes the Tensors written when using this policy with an environment.

validate_args
    Whether action & distribution validate input and output args.
Methods
action
action(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedTensor = (),
    seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates next action given the time_step and policy_state.
Args

time_step
    A TimeStep tuple corresponding to time_step_spec().

policy_state
    A Tensor, or a nested dict, list or tuple of Tensors representing the
    previous policy_state.

seed
    Seed to use if action performs sampling (optional).

Returns

A PolicyStep named tuple containing:
    action: An action Tensor matching the action_spec.
    state: A policy state tensor to be fed into the next call to action.
    info: Optional side information such as action log probabilities.
Raises

RuntimeError
    If the subclass __init__ didn't call super().__init__().

ValueError or TypeError
    If validate_args is True and inputs or outputs do not match
    time_step_spec, policy_state_spec, or policy_step_spec.
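As a hedged illustration, continuing the constructor sketch above (the `policy`, specs, and `ts` import are the assumptions introduced there), a single action query might look like:

    observation = tf.zeros([1, 4], dtype=tf.float32)   # batch of one observation
    time_step = ts.restart(observation, batch_size=1)  # initial (FIRST) time step
    policy_state = policy.get_initial_state(batch_size=1)

    policy_step = policy.action(time_step, policy_state)
    policy_step.action  # Tensor of shape [1, 1], clipped to [-1, 1] since clip=True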
distribution
distribution(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedTensor = ()
) -> tf_agents.trajectories.PolicyStep
Generates the distribution over next actions given the time_step.
Args

time_step
    A TimeStep tuple corresponding to time_step_spec().

policy_state
    A Tensor, or a nested dict, list or tuple of Tensors representing the
    previous policy_state.

Returns

A PolicyStep named tuple containing:
    action: A tf.distribution capturing the distribution of next actions.
    state: A policy state tensor for the next call to distribution.
    info: Optional side information such as action log probabilities.
Raises

ValueError or TypeError
    If validate_args is True and inputs or outputs do not match
    time_step_spec, policy_state_spec, or policy_step_spec.
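For comparison with action above, a sketch of querying the distribution instead of a sampled action (same illustrative `policy`, `time_step`, and `policy_state` as in the earlier sketches):

    dist_step = policy.distribution(time_step, policy_state)
    action_distribution = dist_step.action  # a distribution (or nest of them)
    sampled = action_distribution.sample()  # action() samples like this, then
                                            # clips to spec when clip=True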
get_initial_state
get_initial_state(
    batch_size: Optional[types.Int]
) -> tf_agents.typing.types.NestedTensor
Returns an initial state usable by the policy.
Args

batch_size
    Tensor or constant: size of the batch dimension. Can be None, in which
    case no batch dimension is added.

Returns

A nested object of type policy_state containing properly initialized
Tensors.
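A brief sketch (reusing the illustrative `policy` from above): for a stateless actor network the result is simply an empty tuple, while a stateful (e.g. RNN) network yields zero-initialized tensors matching policy_state_spec.

    policy_state = policy.get_initial_state(batch_size=1)
    # () for stateless networks; a nest of zero tensors otherwise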
update
update(
    policy,
    tau: float = 1.0,
    tau_non_trainable: Optional[float] = None,
    sort_variables_by_name: bool = False
) -> tf.Operation
Update the current policy with another policy.
This would include copying the variables from the other policy.
Args

policy
    Another policy to update from.

tau
    A float scalar in [0, 1]. When tau is 1.0 (the default), we do a hard
    update. This is used for trainable variables.

tau_non_trainable
    A float scalar in [0, 1] for non_trainable variables. If None, will copy
    from tau.

sort_variables_by_name
    A bool; when True, the variables are sorted by name before the update.
Returns

A TF op to do the update.
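A sketch of typical usage, assuming `target_policy` and `policy` are two ActorPolicy instances with identically structured variables (as in a target network in an actor-critic setup); the rate 0.005 is illustrative:

    # Hard update (tau=1.0): target variables are overwritten with the source's.
    update_op = target_policy.update(policy, tau=1.0)

    # Soft (Polyak) update: target_var <- tau * source_var + (1 - tau) * target_var.
    update_op = target_policy.update(policy, tau=0.005)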