tfp.substrates.numpy.bijectors.real_nvp_default_template
Build a scale-and-shift function using a multi-layer neural network.
tfp.substrates.numpy.bijectors.real_nvp_default_template(
    hidden_layers,
    shift_only=False,
    activation=tf.nn.relu,
    name=None,
    *args,
    **kwargs
)
The returned function is wrapped in a make_template so that its variables are created only once. It takes the d-dimensional input x[0:d] and returns the (D - d)-dimensional outputs loc ('mu') and log_scale ('alpha').
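The description above can be illustrated with a minimal pure-Python sketch of this kind of network. It assumes a fully connected network whose hidden layers apply activation and whose final linear layer emits 2 * (D - d) values, split into a shift half and a log_scale half (or only D - d shift values when shift_only=True). The function names, the Gaussian weight initialization, the zero biases, and the closure-based weight cache standing in for make_template are all hypothetical illustrations, not the TFP implementation.

```python
import random


def relu(values):
    """Elementwise ReLU on a list of floats."""
    return [max(0.0, v) for v in values]


def dense(x, kernel, bias, activation=None):
    """Fully connected layer: out[j] = activation(sum_i x[i] * kernel[i][j] + bias[j])."""
    out = [sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
           for j in range(len(bias))]
    return activation(out) if activation is not None else out


def make_shift_and_log_scale_fn(hidden_layers, shift_only=False,
                                activation=relu, seed=0):
    """Hypothetical pure-Python stand-in for a default Real NVP template.

    Weights are created on the first call and reused on later calls, which
    imitates the create-variables-once behaviour that make_template provides.
    """
    rng = random.Random(seed)
    params = []  # (kernel, bias) pairs, filled lazily on the first call

    def fn(x, output_units):
        if not params:
            # Widths: d inputs -> hidden layers -> (1 or 2) * (D - d) outputs.
            widths = [len(x)] + list(hidden_layers)
            widths.append((1 if shift_only else 2) * output_units)
            for n_in, n_out in zip(widths[:-1], widths[1:]):
                kernel = [[rng.gauss(0.0, 0.1) for _ in range(n_out)]
                          for _ in range(n_in)]
                params.append((kernel, [0.0] * n_out))
        h = list(x)
        for kernel, bias in params[:-1]:
            h = dense(h, kernel, bias, activation)  # hidden layers use activation
        kernel, bias = params[-1]
        out = dense(h, kernel, bias)  # final layer is linear
        if shift_only:
            return out, None  # NICE-style: shift term only
        return out[:output_units], out[output_units:]  # shift, log_scale

    return fn


# D = 5 and d = 2: the network reads x[0:2] and emits D - d = 3 values per output.
fn = make_shift_and_log_scale_fn(hidden_layers=[8, 8])
shift, log_scale = fn([0.3, -1.2], output_units=3)
print(len(shift), len(log_scale))  # 3 3
```

Calling fn a second time reuses the same weights, so identical inputs give identical outputs; passing activation=None makes every layer linear.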
The default template does not support conditioning and raises an exception if condition_kwargs are passed to it. To use conditioning in the Real NVP bijector, implement a conditioned shift/scale template that handles the condition_kwargs.
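The conditioning pattern described above can be sketched in plain Python. This assumes the shift/scale function is called as fn(x, output_units, **condition_kwargs); that call shape, the keyword name context, the toy linear map, and both function names are hypothetical and chosen only to contrast a template that rejects conditioning with one that uses it.

```python
def default_template_fn(x, output_units, **condition_kwargs):
    """Toy stand-in for the default template: conditioning is rejected."""
    if condition_kwargs:
        raise NotImplementedError(
            "this template does not support condition_kwargs")
    # A real template would run its network on x; zeros keep the sketch short.
    return [0.0] * output_units, [0.0] * output_units


def conditioned_template_fn(x, output_units, **condition_kwargs):
    """Toy conditioned template.

    Assumes a hypothetical keyword 'context' (a list of floats) that is
    concatenated onto x[0:d] before a fixed toy linear map.
    """
    context = condition_kwargs.get("context", [])
    total = sum(list(x) + list(context))
    shift = [0.1 * (j + 1) * total for j in range(output_units)]
    log_scale = [0.01 * (j + 1) * total for j in range(output_units)]
    return shift, log_scale


try:
    default_template_fn([1.0, 2.0], 3, context=[0.5])
except NotImplementedError as err:
    print("rejected:", err)

shift_a, _ = conditioned_template_fn([1.0, 2.0], 3, context=[0.0])
shift_b, _ = conditioned_template_fn([1.0, 2.0], 3, context=[1.0])
print(shift_a != shift_b)  # True: the condition changes the output
```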
Args:
  hidden_layers: Python list-like of non-negative integer scalars indicating the number of units in each hidden layer. Default: [512, 512].
  shift_only: Python bool indicating whether only the shift term is computed (i.e., a NICE bijector). Default: False.
  activation: Activation function (callable). Explicitly setting it to None implies a linear activation.
  name: A name for ops managed by this function. Default: 'real_nvp_default_template'.
  *args: Additional positional arguments passed to tf.layers.dense.
  **kwargs: Additional keyword arguments passed to tf.layers.dense.
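Together, hidden_layers and shift_only fix the widths of the dense layers. The following sketch computes the kernel shapes and parameter count, assuming the final layer emits 2 * (D - d) units (shift and log_scale) or D - d units when shift_only=True, and that every layer has a bias. The helper names are hypothetical.

```python
def dense_layer_shapes(input_depth, hidden_layers, output_units, shift_only=False):
    """(fan_in, fan_out) of each dense kernel, from input to output.

    Assumes the last layer emits output_units values when shift_only=True
    (shift only, as in NICE) and 2 * output_units values otherwise
    (shift and log_scale stacked together).
    """
    widths = [input_depth] + list(hidden_layers)
    widths.append(output_units if shift_only else 2 * output_units)
    return list(zip(widths[:-1], widths[1:]))


def num_params(shapes):
    """Kernel entries plus one bias per output unit, summed over layers."""
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in shapes)


# d = 2 conditioning inputs, D - d = 3 transformed outputs.
full = dense_layer_shapes(2, [512, 512], 3)
nice = dense_layer_shapes(2, [512, 512], 3, shift_only=True)
print(full)              # [(2, 512), (512, 512), (512, 6)]
print(nice)              # [(2, 512), (512, 512), (512, 3)]
print(num_params(full))  # 267270
```

With the [512, 512] setting, almost all parameters sit in the 512 x 512 hidden-to-hidden kernel, so shift_only=True saves relatively little.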
Returns:
  shift: Float-like Tensor of shift terms ('mu' in [Papamakarios et al. (2017)][1]).
  log_scale: Float-like Tensor of log(scale) terms ('alpha' in [Papamakarios et al. (2017)][1]).
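To show where shift and log_scale are used, here is a minimal pure-Python sketch of the standard Real NVP affine coupling step, in which x[0:d] passes through unchanged and x[d:D] becomes x * exp(log_scale) + shift. The toy shift/scale function and the coupling function names are hypothetical; in practice the Real NVP bijector applies these outputs itself. A log_scale of None (the shift-only case) is treated as a unit scale.

```python
import math


def toy_shift_and_log_scale(x0, output_units):
    """Hypothetical stand-in for the template's network: any function of x[0:d] works."""
    s = sum(x0)
    shift = [0.5 * s + j for j in range(output_units)]
    log_scale = [0.1 * (j + 1) * s for j in range(output_units)]
    return shift, log_scale


def coupling_forward(x, d, fn):
    """y[0:d] = x[0:d]; y[d:] = x[d:] * exp(log_scale) + shift."""
    x0, x1 = x[:d], x[d:]
    shift, log_scale = fn(x0, len(x1))
    if log_scale is None:  # shift-only (NICE) case: unit scale
        log_scale = [0.0] * len(x1)
    y1 = [xi * math.exp(ls) + sh for xi, sh, ls in zip(x1, shift, log_scale)]
    # Triangular Jacobian with diagonal [1] * d + exp(log_scale): log|det| = sum(log_scale).
    log_det_jacobian = sum(log_scale)
    return x0 + y1, log_det_jacobian


def coupling_inverse(y, d, fn):
    """Since y[0:d] == x[0:d], the same shift and log_scale can be recomputed."""
    y0, y1 = y[:d], y[d:]
    shift, log_scale = fn(y0, len(y1))
    if log_scale is None:
        log_scale = [0.0] * len(y1)
    x1 = [(yi - sh) * math.exp(-ls) for yi, sh, ls in zip(y1, shift, log_scale)]
    return y0 + x1


x = [0.2, -0.4, 1.0, 3.0, -2.0]  # D = 5, and d = 2 below
y, log_det = coupling_forward(x, 2, toy_shift_and_log_scale)
x_back = coupling_inverse(y, 2, toy_shift_and_log_scale)
print(all(abs(a - b) < 1e-12 for a, b in zip(x, x_back)))  # True
```

Because the network only ever sees the untouched half, it never needs to be inverted, which is why an arbitrary multi-layer network can serve as the shift-and-scale function.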
Raises:
  NotImplementedError: If the rightmost dimension of the inputs is unknown prior to graph execution, or if condition_kwargs is not empty.
References:
  [1]: George Papamakarios, Theo Pavlakou, and Iain Murray. Masked Autoregressive Flow for Density Estimation. In Neural Information Processing Systems, 2017. https://arxiv.org/abs/1705.07057
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-11-21 UTC.