# tf_agents.policies.SavedModelPyTFEagerPolicy
Exposes a numpy API for saved_model policies in Eager mode.
Inherits From: `PyTFEagerPolicyBase`, `PyPolicy`
```python
tf_agents.policies.SavedModelPyTFEagerPolicy(
    model_path: Text,
    time_step_spec: Optional[tf_agents.trajectories.TimeStep] = None,
    action_spec: Optional[tf_agents.typing.types.DistributionSpecV2] = None,
    policy_state_spec: tf_agents.typing.types.NestedTensorSpec = (),
    info_spec: tf_agents.typing.types.NestedTensorSpec = (),
    load_specs_from_pbtxt: bool = False,
    use_tf_function: bool = False,
    batch_time_steps=True
)
```
### Used in the notebooks

Used in the tutorials: [Checkpointer and PolicySaver](https://www.tensorflow.org/agents/tutorials/10_checkpointer_policysaver_tutorial)
#### Args

| Argument | Description |
|---|---|
| `model_path` | Path to a saved_model generated by the `policy_saver`. |
| `time_step_spec` | Optional nested structure of `ArraySpec`s describing the policy's `time_step_spec`. This is not used by the `SavedModelPyTFEagerPolicy`, but may be accessed by other objects as it is part of the public policy API. |
| `action_spec` | Optional nested structure of `ArraySpec`s describing the policy's `action_spec`. This is not used by the `SavedModelPyTFEagerPolicy`, but may be accessed by other objects as it is part of the public policy API. |
| `policy_state_spec` | Optional nested structure of `ArraySpec`s describing the policy's `policy_state_spec`. This is not used by the `SavedModelPyTFEagerPolicy`, but may be accessed by other objects as it is part of the public policy API. |
| `info_spec` | Optional nested structure of `ArraySpec`s describing the policy's `info_spec`. This is not used by the `SavedModelPyTFEagerPolicy`, but may be accessed by other objects as it is part of the public policy API. |
| `load_specs_from_pbtxt` | If True, the specs will be loaded from the proto file generated by the `policy_saver`. |
| `use_tf_function` | See `PyTFEagerPolicyBase`. |
| `batch_time_steps` | See `PyTFEagerPolicyBase`. |
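A minimal loading sketch follows; the directory path is hypothetical and stands in for whatever directory `PolicySaver.save()` wrote:

```python
from tf_agents.policies import py_tf_eager_policy

# Load a policy exported with tf_agents.policies.PolicySaver. The path
# '/tmp/saved_policy' is a placeholder for your own export directory.
saved_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    model_path='/tmp/saved_policy',
    # Recover the time_step/action/info specs from the proto file that
    # PolicySaver writes alongside the SavedModel, instead of passing them in.
    load_specs_from_pbtxt=True,
)
```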
#### Attributes

| Attribute | Description |
|---|---|
| `action_spec` | Describes the ArraySpecs of the np.Array returned by `action()`. `action` can be a single np.Array, or a nested dict, list or tuple of np.Array. |
| `collect_data_spec` | Describes the data collected when using this policy with an environment. |
| `info_spec` | Describes the Arrays emitted as info by `action()`. |
| `observation_and_action_constraint_splitter` | |
| `policy_state_spec` | Describes the arrays expected by functions with `policy_state` as input. |
| `policy_step_spec` | Describes the output of `action()`. |
| `time_step_spec` | Describes the `TimeStep` np.Arrays expected by `action(time_step)`. |
| `trajectory_spec` | Describes the data collected when using this policy with an environment. |
## Methods

### `action`

[View source](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/policies/py_policy.py#L147-L169)

```python
action(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedArray = (),
    seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
```
Generates next action given the time_step and policy_state.
#### Args

| Argument | Description |
|---|---|
| `time_step` | A `TimeStep` tuple corresponding to `time_step_spec()`. |
| `policy_state` | An optional previous policy_state. |
| `seed` | Seed to use if action uses sampling (optional). |
#### Returns

A `PolicyStep` named tuple containing:

- `action`: A nest of action Arrays matching the `action_spec()`.
- `state`: A nest of policy states to be fed into the next call to action.
- `info`: Optional side information such as action log probabilities.
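As a hedged sketch, a typical rollout loop threads `state` back into the next `action()` call; here `env` is assumed to be an existing `tf_agents.environments.PyEnvironment` and `saved_policy` the instance loaded above:

```python
# Run one episode with the loaded policy. With the default
# batch_time_steps=True, the unbatched time steps produced by a
# PyEnvironment are batched before being fed to the saved model.
time_step = env.reset()
policy_state = saved_policy.get_initial_state(batch_size=1)
while not time_step.is_last():
    policy_step = saved_policy.action(time_step, policy_state)
    policy_state = policy_step.state  # feed state into the next call
    time_step = env.step(policy_step.action)
```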
### `get_initial_state`

[View source](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/policies/py_policy.py#L134-L145)

```python
get_initial_state(
    batch_size: Optional[int] = None
) -> tf_agents.typing.types.NestedArray
```
Returns an initial state usable by the policy.
#### Args

| Argument | Description |
|---|---|
| `batch_size` | An optional batch size. |

#### Returns

An initial policy state.
### `get_metadata`

[View source](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/policies/py_tf_eager_policy.py#L246-L248)

```python
get_metadata()
```

Returns the metadata of the saved model.
### `get_train_step`

[View source](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/policies/py_tf_eager_policy.py#L238-L240)

```python
get_train_step() -> tf_agents.typing.types.Int
```

Returns the training global step of the saved model.
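For example (a small sketch reusing the hypothetical `saved_policy` from above), the recorded train step can label evaluation results when comparing several exported versions of a policy:

```python
# Report which training step the loaded policy was exported at.
train_step = saved_policy.get_train_step()
print(f'Evaluating policy exported at train step {train_step}.')
```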
### `get_train_step_from_last_restored_checkpoint_path`

[View source](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/policies/py_tf_eager_policy.py#L242-L244)

```python
get_train_step_from_last_restored_checkpoint_path() -> Optional[int]
```

Returns the training step of the restored checkpoint.
### `update_from_checkpoint`

[View source](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/policies/py_tf_eager_policy.py#L253-L289)

```python
update_from_checkpoint(
    checkpoint_path: Text
)
```

Allows users to update saved_model variables directly from a checkpoint.

`checkpoint_path` is a path that was passed to either `PolicySaver.save()` or `PolicySaver.save_checkpoint()`. The policy looks for a set of checkpoint files with the file prefix `variables/variables`.
#### Args

| Argument | Description |
|---|---|
| `checkpoint_path` | Path to the checkpoint to restore and use to update this policy. |
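This makes it possible to refresh an already-loaded policy's variables without re-loading the entire SavedModel. A hedged sketch, with a hypothetical checkpoint path:

```python
# Pull fresh variable values from a checkpoint written by
# PolicySaver.save_checkpoint(); the path below is a placeholder.
saved_policy.update_from_checkpoint('/tmp/policy_checkpoint')
# After a restore, this should report the checkpoint's train step
# (it may return None if no checkpoint has been restored yet).
print(saved_policy.get_train_step_from_last_restored_checkpoint_path())
```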
### `variables`

[View source](https://github.com/tensorflow/agents/blob/v0.19.0/tf_agents/policies/py_tf_eager_policy.py#L250-L251)

```python
variables()
```