tf_agents.keras_layers.SquashedOuterWrapper
Squash the outer dimensions of input tensors; unsquash outputs.
tf_agents.keras_layers.SquashedOuterWrapper(
wrapped: tf.keras.layers.Layer, inner_rank: int, **kwargs
)
This layer wraps a Keras layer, wrapped, that cannot handle more than one batch dimension. It squashes the inputs' outer dimensions into a single, larger batch dimension, then unsquashes the outputs of wrapped.
The outer dimensions are the leftmost rank(inputs) - inner_rank dimensions.
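To make the squash/unsquash idea concrete, the sketch below reproduces it with NumPy reshapes so that it runs without TensorFlow. The helper squash_apply_unsquash and the stand-in function per_example_mean are hypothetical. This illustrates the technique only; it is not the TF-Agents implementation, which operates on TensorFlow tensors and Keras layers.

```python
import numpy as np


def squash_apply_unsquash(fn, x, inner_rank):
    """Apply fn, which accepts exactly one batch dimension, to x.

    Hypothetical helper sketching the squash/unsquash idea with NumPy;
    it is not the TF-Agents implementation.
    """
    x = np.asarray(x)
    outer_rank = x.ndim - inner_rank
    if outer_rank < 1:
        raise ValueError("x must have at least one outer dimension")
    outer_shape = x.shape[:outer_rank]
    # Squash: merge all outer dimensions into a single batch dimension.
    squashed = x.reshape((-1,) + x.shape[outer_rank:])
    y = np.asarray(fn(squashed))
    # Unsquash: restore the outer dimensions, keeping fn's own inner shape.
    return y.reshape(outer_shape + y.shape[1:])


def per_example_mean(batch):
    """Stand-in for a layer that only accepts rank-4 [N, H, W, C] inputs."""
    if batch.ndim != 4:
        raise ValueError(f"expected rank 4, got rank {batch.ndim}")
    return batch.mean(axis=(1, 2))  # shape [N, C]


# Inputs with one, two, and three outer dimensions (inner_rank=3 each time).
for shape in [(2, 8, 8, 4), (2, 3, 8, 8, 4), (2, 2, 3, 8, 8, 4)]:
    out = squash_apply_unsquash(per_example_mean, np.ones(shape), inner_rank=3)
    print(shape, "->", out.shape)
```

In this sketch, the unsquash step reuses the output's own inner shape rather than the input's. A wrapped computation that changes the inner dimensions (here, [H, W, C] to [C]) therefore still gets its outer dimensions restored.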
Examples:
import tensorflow as tf
from tf_agents.keras_layers import SquashedOuterWrapper

# Example sizes; any positive integers work.
B, T, H, W, C = 2, 3, 8, 8, 4
B1, B2 = 2, 5

batch_norm = tf.keras.layers.BatchNormalization(axis=-1)
layer = SquashedOuterWrapper(wrapped=batch_norm, inner_rank=3)

inputs_0 = tf.random.normal((B, H, W, C))
# batch_norm sees a tensor of shape [B, H, W, C]
# outputs_0 shape is [B, H, W, C]
outputs_0 = layer(inputs_0)

inputs_1 = tf.random.normal((B, T, H, W, C))
# batch_norm sees a tensor of shape [B * T, H, W, C]
# outputs_1 shape is [B, T, H, W, C]
outputs_1 = layer(inputs_1)

inputs_2 = tf.random.normal((B1, B2, T, H, W, C))
# batch_norm sees a tensor of shape [B1 * B2 * T, H, W, C]
# outputs_2 shape is [B1, B2, T, H, W, C]
outputs_2 = layer(inputs_2)
Args
wrapped: The Keras layer to wrap.
inner_rank: The inner rank of the inputs that will be passed to the wrapped layer. This value allows us to infer the outer batch dimensions regardless of the input shape passed to build or call.
**kwargs: Additional arguments for Keras layer construction.
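inner_rank counts dimensions from the right, so the same value splits inputs of any rank. The pure-Python sketch below shows that arithmetic on static shapes, treating None as an unknown dimension in the way Keras shapes do. The helpers split_outer_inner and squashed_shape are hypothetical illustrations, not TF-Agents functions.

```python
import math
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


def split_outer_inner(shape: Sequence[Optional[int]], inner_rank: int) -> Tuple[Shape, Shape]:
    """Split shape into (outer, inner); inner is the rightmost inner_rank dims."""
    if not 0 <= inner_rank < len(shape):
        raise ValueError("need 0 <= inner_rank < len(shape)")
    outer_rank = len(shape) - inner_rank
    return tuple(shape[:outer_rank]), tuple(shape[outer_rank:])


def squashed_shape(shape: Sequence[Optional[int]], inner_rank: int) -> Shape:
    """Shape the wrapped layer would see after squashing the outer dims."""
    outer, inner = split_outer_inner(shape, inner_rank)
    # Any unknown (None) outer dimension makes the merged batch size unknown.
    batch = None if any(d is None for d in outer) else math.prod(outer)
    return (batch,) + inner


# The same inner_rank=3 works for inputs of rank 4, 5, and 6.
print(squashed_shape((2, 8, 8, 4), inner_rank=3))        # (2, 8, 8, 4)
print(squashed_shape((2, 3, 8, 8, 4), inner_rank=3))     # (6, 8, 8, 4)
print(squashed_shape((None, 3, 8, 8, 4), inner_rank=3))  # (None, 8, 8, 4)
```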
Raises
ValueError: If wrapped has a get_initial_state method. We do not know how to handle the case of multiple inputs, and the presence of this method typically indicates an RNN or RNN-like layer that accepts separate state tensors.
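The ValueError condition above rules out RNN-like layers. Keras recurrent layers such as tf.keras.layers.LSTM define get_initial_state. The sketch below shows an equivalent pre-check you might run yourself, assuming a simple hasattr test on the method name. check_wrappable and the two stand-in classes are hypothetical and only mimic the relevant part of a Keras layer's interface.

```python
def check_wrappable(layer) -> None:
    """Raise ValueError if layer looks RNN-like (hypothetical pre-check).

    Assumes the rule is a plain attribute test for get_initial_state.
    """
    if hasattr(layer, "get_initial_state"):
        raise ValueError(
            f"{type(layer).__name__} defines get_initial_state; "
            "RNN-like layers with separate state inputs cannot be wrapped."
        )


# Hypothetical stand-ins for real Keras layers.
class FeedForwardLayer:
    def __call__(self, x):
        return x


class RecurrentLayer:
    def get_initial_state(self, batch_size):
        return [0.0] * batch_size

    def __call__(self, x, state):
        return x, state


check_wrappable(FeedForwardLayer())  # passes silently
try:
    check_wrappable(RecurrentLayer())
except ValueError as err:
    print("rejected:", err)
```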
Attributes
inner_rank
wrapped
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.