tfp.layers.default_mean_field_normal_fn
Creates a function to build Normal distributions with trainable params.
tfp.layers.default_mean_field_normal_fn(
    is_singular=False,
    loc_initializer=tf1.initializers.random_normal(stddev=0.1),
    untransformed_scale_initializer=tf1.initializers.random_normal(mean=-3.0, stddev=0.1),
    loc_regularizer=None,
    untransformed_scale_regularizer=None,
    loc_constraint=None,
    untransformed_scale_constraint=None
)
This function produces a closure which produces a tfd.Normal parameterized by a loc and a scale, each created using tf.get_variable.
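The closure pattern described above can be sketched in plain NumPy. This is a hypothetical, simplified analogue, not the TFP implementation: NormalSketch, make_mean_field_normal_fn and the add_variable_fn stand-in are illustrative names, the initializers take (shape, rng) instead of TensorFlow's initializer signature, and the variable names name + "_loc" and name + "_untransformed_scale" are an assumption made for the example.

```python
import numpy as np
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class NormalSketch:
    """Hypothetical stand-in for tfd.Normal; scale=None plays the role of Deterministic."""
    loc: np.ndarray
    scale: Optional[np.ndarray]

    def sample(self, rng: np.random.Generator) -> np.ndarray:
        if self.scale is None:
            # is_singular=True: all mass sits at loc.
            return self.loc.copy()
        # Reparameterized draw: loc + scale * eps with eps ~ N(0, 1).
        return self.loc + self.scale * rng.standard_normal(self.loc.shape)


def softplus(x):
    # Numerically stable log(1 + exp(x)); maps any real number to a positive one.
    return np.logaddexp(0.0, x)


def make_mean_field_normal_fn(
    is_singular: bool = False,
    loc_initializer: Callable = lambda shape, rng: rng.normal(0.0, 0.1, shape),
    untransformed_scale_initializer: Callable = lambda shape, rng: rng.normal(-3.0, 0.1, shape),
) -> Callable:
    """Hypothetical NumPy analogue of default_mean_field_normal_fn (not the TFP code)."""

    def _fn(dtype, shape, name, trainable, add_variable_fn):
        loc = add_variable_fn(name=name + "_loc", shape=shape, dtype=dtype,
                              initializer=loc_initializer, trainable=trainable)
        if is_singular:
            return NormalSketch(loc=loc, scale=None)
        untransformed_scale = add_variable_fn(
            name=name + "_untransformed_scale", shape=shape, dtype=dtype,
            initializer=untransformed_scale_initializer, trainable=trainable)
        # The stored variable is unconstrained; softplus turns it into a positive scale.
        return NormalSketch(loc=loc, scale=softplus(untransformed_scale))

    return _fn


# Hypothetical stand-in for tf.get_variable: initializes and records each variable.
rng = np.random.default_rng(0)
store = {}


def add_variable_fn(name, shape, dtype, initializer, trainable):
    # `trainable` is accepted for signature compatibility but ignored here.
    store[name] = np.asarray(initializer(shape, rng), dtype=dtype)
    return store[name]


make_normal_fn = make_mean_field_normal_fn()
posterior = make_normal_fn(np.float32, (3, 2), "kernel", True, add_variable_fn)
print(sorted(store))  # ['kernel_loc', 'kernel_untransformed_scale']
```

The factory captures the configuration once and returns a closure, so the caller decides later which dtype, shape and variable-creation function to use.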
Args:
  is_singular: Python bool. If True, forces the special-case limit
    scale -> 0, i.e., a Deterministic distribution.
  loc_initializer: Initializer function for the loc parameters. Default
    value: tf.random_normal_initializer(mean=0., stddev=0.1).
  untransformed_scale_initializer: Initializer function for the
    untransformed scale parameters. Default value:
    tf.random_normal_initializer(mean=-3., stddev=0.1). Because the scale
    is the softplus of this value, it starts near 0 (softplus(-3) ≈ 0.049),
    so a Normal distribution with this scale initially acts approximately
    like a point mass.
  loc_regularizer: Regularizer function for the loc parameters.
  untransformed_scale_regularizer: Regularizer function for the
    untransformed scale parameters.
  loc_constraint: An optional projection function to be applied to the
    loc after it is updated by an Optimizer. The function must take the
    unprojected variable as input and return the projected variable, which
    must have the same shape. Constraints are not safe to use when doing
    asynchronous distributed training.
  untransformed_scale_constraint: An optional projection function to be
    applied to the untransformed scale parameters after they are updated
    by an Optimizer (e.g., to implement norm constraints or value
    constraints). The function must take the unprojected variable as input
    and return the projected variable, which must have the same shape.
    Constraints are not safe to use when doing asynchronous distributed
    training.
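The effect of the default untransformed_scale_initializer, and what a constraint does, can be checked numerically. This NumPy sketch assumes the scale is softplus(untransformed_scale), as the initializer description states; the clipping constraint is a hypothetical example (in TensorFlow a comparable projection might be written with tf.clip_by_value).

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


# Mimic the default untransformed_scale_initializer: draws from N(mean=-3, stddev=0.1).
rng = np.random.default_rng(42)
untransformed = rng.normal(loc=-3.0, scale=0.1, size=1000)
scale = softplus(untransformed)

print(f"{softplus(-3.0):.4f}")  # 0.0486
print(scale.min() > 0.0, scale.max() < 0.1)  # True True


# Hypothetical constraint: a shape-preserving projection that keeps the
# untransformed scale in [-5, 5], so the scale stays in roughly [0.0067, 5.0067].
def untransformed_scale_constraint(v):
    return np.clip(v, -5.0, 5.0)


updated = np.array([-7.0, -3.0, 6.5])  # pretend values right after an optimizer step
projected = untransformed_scale_constraint(updated)
print(projected)  # [-5. -3.  5.]
```

Every initial scale lands close to 0.05, which is why the freshly initialized distribution behaves almost like a point mass at loc.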
Returns:
  make_normal_fn: Python callable which creates a tfd.Normal from the
    args dtype, shape, name, trainable, and add_variable_fn.
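The returned callable delegates variable creation to add_variable_fn. In TFP, Bayesian layers such as tfp.layers.DenseReparameterization use default_mean_field_normal_fn() as their default kernel_posterior_fn. The sketch below is a hypothetical NumPy variable store that a caller could pass in the add_variable_fn role. It assumes add_variable_fn receives keyword arguments such as name, shape, dtype, initializer, trainable, regularizer and constraint, and it simulates one optimizer update to show where the constraint projection runs. VariableStore, apply_update and regularization_loss are illustrative names, not TFP or TensorFlow APIs.

```python
import numpy as np


class VariableStore:
    """Hypothetical stand-in for the add_variable_fn a caller passes to make_normal_fn.

    Initializers here take (shape, rng); regularizer and constraint are optional
    callables, loosely mirroring tf.compat.v1.get_variable's keyword arguments.
    """

    def __init__(self, seed=0):
        self.rng = np.random.default_rng(seed)
        self.variables, self.regularizers, self.constraints = {}, {}, {}

    def add_variable_fn(self, name, shape, dtype, initializer, trainable,
                        regularizer=None, constraint=None):
        self.variables[name] = np.asarray(initializer(shape, self.rng), dtype=dtype)
        if regularizer is not None:
            self.regularizers[name] = regularizer
        if trainable and constraint is not None:
            self.constraints[name] = constraint
        return self.variables[name]

    def regularization_loss(self):
        # Penalties are evaluated on the current variable values.
        return sum(reg(self.variables[n]) for n, reg in self.regularizers.items())

    def apply_update(self, name, delta):
        # Simulated optimizer step, followed by the optional projection (constraint).
        updated = self.variables[name] + delta
        if name in self.constraints:
            updated = self.constraints[name](updated)
        self.variables[name] = updated
        return updated


store = VariableStore()
store.add_variable_fn(
    name="kernel_loc", shape=(2,), dtype=np.float64,
    initializer=lambda shape, rng: rng.normal(0.0, 0.1, shape),
    trainable=True,
    regularizer=lambda v: 0.01 * np.sum(v ** 2),  # L2 penalty on loc
    constraint=lambda v: np.clip(v, -1.0, 1.0),   # shape-preserving projection
)
print(store.apply_update("kernel_loc", np.array([5.0, -5.0])))  # [ 1. -1.]
print(round(float(store.regularization_loss()), 4))  # 0.02
```

The constraint runs only after the update, matching the "applied after being updated by an Optimizer" wording in the argument descriptions.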
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.
Last updated 2023-11-21 UTC.