tf_agents.utils.nest_utils.is_batched_nested_tensors
Compares tensors to specs to determine if all tensors are batched or not.
tf_agents.utils.nest_utils.is_batched_nested_tensors(
    tensors,
    specs,
    num_outer_dims=1,
    allow_extra_fields=False,
    check_dtypes=True
)
For each tensor, it checks the dimensions and dtypes with respect to specs.
Returns True if all tensors are batched and False if all tensors are unbatched.
Raises a ValueError if the shapes are incompatible or a mix of batched and unbatched tensors is provided.
Raises a TypeError if the tensors' dtypes do not match specs.
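The decision rule described above can be modelled on plain shape tuples. The following is a minimal sketch, not the TF-Agents implementation: `is_batched_shapes` and `_compatible` are hypothetical helpers, the inputs are assumed to be already-flattened lists of shapes, and dtype checking is left out.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]  # None marks an unknown dimension


def _compatible(spec_shape: Shape, shape: Shape) -> bool:
    """Same rank, and every pair of known dimensions agrees."""
    if len(spec_shape) != len(shape):
        return False
    return all(s is None or d is None or s == d
               for s, d in zip(spec_shape, shape))


def is_batched_shapes(shapes: Sequence[Shape],
                      spec_shapes: Sequence[Shape],
                      num_outer_dims: int = 1) -> bool:
    """Hypothetical shape-only model of the batched/unbatched decision."""
    pairs = list(zip(spec_shapes, shapes))
    # Unbatched: every shape already matches its spec.
    if all(_compatible(s, t) for s, t in pairs):
        return False
    # Batched: exactly num_outer_dims extra leading dimensions everywhere,
    # and the remaining trailing dimensions match the spec.
    if all(len(t) - len(s) == num_outer_dims
           and _compatible(s, t[num_outer_dims:]) for s, t in pairs):
        return True
    raise ValueError("mixed, incompatible, or wrong number of outer dims")


print(is_batched_shapes([(3, 4)], [(4,)]))  # True: one leading batch dim
print(is_batched_shapes([(4,)], [(4,)]))    # False: matches the spec as-is
```

In this model, a mix such as `[(3, 4), (2,)]` checked against specs `[(4,), (2,)]` raises a ValueError, mirroring the mixed batched/unbatched case documented for the real function.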
Args |
tensors
|
Nested list/tuple/dict of Tensors.
|
specs
|
Nested list/tuple/dict of Tensors or CompositeTensors describing the
shape of unbatched tensors.
|
num_outer_dims
|
The integer number of dimensions that are considered batch
dimensions. Default 1.
|
allow_extra_fields
|
If True, then tensors may have extra subfields which are not in specs. In this case, the extra subfields will not be checked. For example:

    tensors = {"a": tf.zeros((3, 4), dtype=tf.float32),
               "b": tf.zeros((5, 6), dtype=tf.float32)}
    specs = {"a": tf.TensorSpec(shape=(4,), dtype=tf.float32)}
    assert is_batched_nested_tensors(tensors, specs, allow_extra_fields=True)

The above example would raise a ValueError if allow_extra_fields was False.
|
check_dtypes
|
If True, validates that tensors and specs have the same dtypes.
|
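To illustrate what "extra subfields will not be checked" means for allow_extra_fields, the sketch below drops every field of a nested structure that has no counterpart in the specs before any comparison happens. This is an assumption-laden model on plain Python containers: `prune_extra_fields` is a hypothetical name, leaves are stand-in strings rather than tensors, and the library's own pruning logic may differ in detail.

```python
from typing import Any


def prune_extra_fields(tensors: Any, specs: Any) -> Any:
    """Keep only the parts of `tensors` that also appear in `specs`."""
    if isinstance(specs, dict) and isinstance(tensors, dict):
        # Walk the keys of specs; keys present only in tensors are dropped.
        return {key: prune_extra_fields(tensors[key], sub_spec)
                for key, sub_spec in specs.items() if key in tensors}
    if (isinstance(specs, (list, tuple))
            and isinstance(tensors, (list, tuple))
            and len(specs) == len(tensors)):
        # Plain lists and tuples are pruned element by element.
        return type(tensors)(prune_extra_fields(t, s)
                             for t, s in zip(tensors, specs))
    # Leaves (and structures that do not line up) are returned unchanged.
    return tensors


tensors = {"a": "tensor_a", "b": "tensor_b"}  # stand-ins for real tensors
specs = {"a": "spec_a"}
print(prune_extra_fields(tensors, specs))  # {'a': 'tensor_a'}
```

After pruning, only field "a" remains to be compared against its spec, which is why the extra field "b" no longer causes a structure mismatch in this model.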
Returns |
True if all Tensors are batched and False if all Tensors are unbatched.
|
Raises |
ValueError
|
If
- Any of the tensors or specs have shapes with ndims == None, or
- The shapes of the tensors are not compatible with specs, or
- A mix of batched and unbatched tensors is provided, or
- The tensors are batched but have an incorrect number of outer dims.
|
TypeError
|
If dtypes between tensors and specs are not compatible.
|
Last updated 2024-04-26 UTC.