Builds a variational posterior by linearly transforming base distributions.
tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution_stateless(
    base_distribution,
    operators='diag',
    bijector=None,
    initial_unconstrained_loc_fn=_sample_uniform_initial_loc,
    validate_args=False,
    name=None
)
This function builds a surrogate posterior by applying a trainable
transformation to a base distribution (typically a tfd.JointDistribution) or
nested structure of base distributions, and constraining the samples with
bijector. Note that the distributions must have event shapes corresponding
to the pretransformed surrogate posterior -- that is, if bijector contains
a shape-changing bijector, then the corresponding base distribution event
shape is the inverse event shape of the bijector applied to the desired
surrogate posterior shape. The surrogate posterior is constructed as follows:
- Flatten the base distribution event shapes to vectors, and pack the base distributions into a tfd.JointDistribution.
- Apply a trainable blockwise LinearOperator bijector to the joint base distribution.
- Apply the constraining bijectors and return the resulting trainable tfd.TransformedDistribution instance.
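The three steps above can be illustrated without TFP. The following is a minimal pure-Python sketch, not the library's implementation. The helper names `affine_transform` and `constrain` are hypothetical, the lower-triangular matrix stands in for the trainable blockwise LinearOperator, and `exp` stands in for a constraining bijector onto the positive reals.

```python
import math
import random


def affine_transform(z, loc, scale_tril):
    """Returns loc + scale_tril @ z for a flattened base sample z."""
    return [m + sum(l_ij * z_j for l_ij, z_j in zip(row, z))
            for m, row in zip(loc, scale_tril)]


def constrain(u):
    """Toy constraining bijector: maps each real value to a positive one."""
    return [math.exp(u_i) for u_i in u]


# Step 1: draw a flattened sample from a standard normal base distribution.
rng = random.Random(0)
z = [rng.gauss(0.0, 1.0) for _ in range(2)]

# Step 2: apply the affine map (fixed here; trainable in the real posterior).
loc = [0.5, -1.0]
scale_tril = [[2.0, 0.0],
              [1.0, 3.0]]
u = affine_transform(z, loc, scale_tril)

# Step 3: constrain the result to the support of the posterior.
sample = constrain(u)
```

In the real function the location and scale are trainable parameters, and the constraint is whatever bijector the caller supplies.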
| Args | |
|---|---|
| base_distribution | tfd.Distribution instance (typically a tfd.JointDistribution), or a nested structure of tfd.Distribution instances. |
| operators | Either a string or a list/tuple containing LinearOperator subclasses, LinearOperator instances, or callables returning LinearOperator instances. Supported string values are "diag" (to create a mean-field surrogate posterior) and "tril" (to create a full-covariance surrogate posterior). A list/tuple may be passed to induce other posterior covariance structures. If the list is flat, a tf.linalg.LinearOperatorBlockDiag instance will be created and applied to the base distribution. Otherwise the list must be singly-nested and have a first element of length 1, second element of length 2, etc.; the elements of the outer list are interpreted as rows of a lower-triangular block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular instance is created. For complete documentation and examples, see tfp.experimental.vi.util.build_trainable_linear_operator_block, which receives the operators arg if it is list-like. Default value: "diag". |
| bijector | tfb.Bijector instance, or nested structure of tfb.Bijector instances, that maps (nested) values in R^n to the support of the posterior. (This can be the experimental_default_event_space_bijector of the distribution over the prior latent variables.) Default value: None (i.e., the posterior is over R^n). |
| initial_unconstrained_loc_fn | Optional Python callable with signature initial_loc = initial_unconstrained_loc_fn(shape, dtype, seed) used to sample real-valued initializations for the unconstrained location of each variable. Default value: functools.partial(tf.random.stateless_uniform, minval=-2., maxval=2., dtype=tf.float32). |
| validate_args | Python bool. Whether to validate input with asserts. This imposes a runtime cost. If validate_args is False, and the inputs are invalid, correct behavior is not guaranteed. Default value: False. |
| name | Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_affine_surrogate_posterior_from_base_distribution'). |
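The flat versus singly-nested rule for a list-valued operators argument can be sketched as a small validity check. This is an illustration only: `classify_operators` is a hypothetical helper, not part of TFP, and plain strings stand in for LinearOperator subclasses.

```python
def classify_operators(operators):
    """Classifies a list-valued `operators` argument by its nesting."""
    rows_nested = [isinstance(row, (list, tuple)) for row in operators]
    if not any(rows_nested):
        # Flat list: one operator per block on the diagonal.
        return 'block_diag'
    if not all(rows_nested):
        raise ValueError('Mix of nested and non-nested entries.')
    for i, row in enumerate(operators):
        # Row i of a lower-triangular block structure has i + 1 blocks.
        if len(row) != i + 1:
            raise ValueError(
                f'Row {i} has {len(row)} blocks; expected {i + 1}.')
    return 'block_lower_triangular'


# Strings are placeholders for LinearOperator subclasses.
flat = ['Diag', 'LowerTriangular', 'Diag']
nested = [['LowerTriangular'],
          ['FullMatrix', 'LowerTriangular'],
          ['FullMatrix', 'FullMatrix', 'Diag']]
```

Here `classify_operators(flat)` reports a block-diagonal layout and `classify_operators(nested)` a block lower-triangular one, while a nested list whose rows do not grow by one block per row is rejected.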
| Returns | |
|---|---|
| init_fn | Python callable with signature initial_parameters = init_fn(seed). | 
| apply_fn | Python callable with signature instance = apply_fn(*parameters). | 
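The returned pair follows a stateless pattern: init_fn creates parameter values from a seed, and apply_fn rebuilds the posterior from those values. The sketch below mimics that pattern in pure Python under stated assumptions. `build_toy_surrogate` and `uniform_initial_loc` are hypothetical names, the dictionary stands in for a distribution instance, and the parameter layout (a location and a log-scale) is invented for illustration rather than taken from TFP.

```python
import math
import random


def uniform_initial_loc(shape, dtype, seed):
    """Toy initializer with the (shape, dtype, seed) signature."""
    rng = random.Random(seed)
    size = math.prod(shape)
    return [dtype(rng.uniform(-2.0, 2.0)) for _ in range(size)]


def build_toy_surrogate(event_size, initial_loc_fn=uniform_initial_loc):
    """Returns an (init_fn, apply_fn) pair for a toy diagonal posterior."""

    def init_fn(seed):
        # All randomness comes from `seed`; no hidden state is created.
        loc = initial_loc_fn([event_size], float, seed)
        log_scale = [0.0] * event_size
        return loc, log_scale

    def apply_fn(loc, log_scale):
        # Rebuilds the "posterior" purely from the supplied parameters.
        return {'loc': list(loc), 'scale': [math.exp(s) for s in log_scale]}

    return init_fn, apply_fn


init_fn, apply_fn = build_toy_surrogate(event_size=3)
initial_parameters = init_fn(seed=42)
posterior = apply_fn(*initial_parameters)
```

Because the parameters are explicit values rather than variables owned by the posterior, they can be handed to an external optimizer and passed back through apply_fn at each step.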
Examples
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
# Fit a multivariate Normal surrogate posterior on the Eight Schools model
# [1].
treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
def model_fn():
  avg_effect = yield tfd.Normal(loc=0., scale=10., name='avg_effect')
  log_stddev = yield tfd.Normal(loc=5., scale=1., name='log_stddev')
  school_effects = yield tfd.Sample(
      tfd.Normal(loc=avg_effect, scale=tf.exp(log_stddev)),
      sample_shape=[8],
      name='school_effects')
  treatment_effects = yield tfd.Independent(
      tfd.Normal(loc=school_effects, scale=treatment_stddevs),
      reinterpreted_batch_ndims=1,
      name='treatment_effects')
model = tfd.JointDistributionCoroutineAutoBatched(model_fn)
# Pin the observed values in the model.
target_model = model.experimental_pin(treatment_effects=treatment_effects)
# Define a lower triangular structure of `LinearOperator` subclasses that
# models full covariance among latent variables except for the 8 dimensions
# of `school_effect`, which are modeled as independent (using
# `LinearOperatorDiag`).
operators = [
  [tf.linalg.LinearOperatorLowerTriangular],
  [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorLowerTriangular],
  [tf.linalg.LinearOperatorFullMatrix, tf.linalg.LinearOperatorFullMatrix,
   tf.linalg.LinearOperatorDiag]]
# Constrain the posterior values to the support of the prior.
bijector = target_model.experimental_default_event_space_bijector()
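# Base distribution over the unconstrained latents. Standard normals with
# matching event shapes are one possible choice (an assumption made here for
# illustration).
base_distribution = tf.nest.map_structure(
    lambda s: tfd.Sample(tfd.Normal(0., 1.), sample_shape=s),
    target_model.event_shape_tensor())
# Build the surrogate posterior with the block structure defined above.
surrogate_posterior = (
  tfp.experimental.vi.build_affine_surrogate_posterior_from_base_distribution(
      base_distribution=base_distribution,
      operators=operators,
      bijector=bijector))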
# Fit the model.
losses = tfp.vi.fit_surrogate_posterior(
    target_model.unnormalized_log_prob,
    surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)
References
[1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and Donald Rubin. Bayesian Data Analysis, Third Edition. Chapman and Hall/CRC, 2013.