Returns stateless functions for building a variational posterior.
tfp.substrates.numpy.sts.build_factored_surrogate_posterior_stateless(
    model, batch_shape=(), name=None
)
The surrogate posterior consists of independent Normal distributions, one
for each model parameter, each with a trainable loc and scale. Each Normal is
transformed by the parameter's bijector so that its samples lie in that
parameter's support.
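The construction described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the parameter names, the use of softplus as the positivity bijector, and the helper sample_factored_surrogate are all assumptions made for the example.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x)); maps any real to a positive real."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sample_factored_surrogate(params, bijectors, rng):
    """Draw one joint sample from a factored (mean-field) surrogate.

    params:    {name: (loc, raw_scale)} defined in unconstrained space.
    bijectors: {name: callable} mapping unconstrained values to the support.
    """
    sample = {}
    for name, (loc, raw_scale) in params.items():
        scale = softplus(raw_scale)        # keep the Normal's scale positive
        z = rng.gauss(loc, scale)          # independent Normal per parameter
        sample[name] = bijectors[name](z)  # push into the parameter's support
    return sample


# Hypothetical parameters: one positive-valued scale, one unconstrained slope.
params = {'drift_scale': (-1.0, 0.5), 'slope': (0.0, 0.5)}
bijectors = {'drift_scale': softplus, 'slope': lambda z: z}

rng = random.Random(0)
draws = [sample_factored_surrogate(params, bijectors, rng) for _ in range(1000)]
```

Because each parameter is sampled independently, the surrogate cannot represent posterior correlations between parameters. The bijector only guarantees that every draw is a valid value for its parameter.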
Examples
Assume we've built a structural time-series model:
day_of_week = tfp.sts.Seasonal(
    num_seasons=7,
    observed_time_series=observed_time_series,
    name='day_of_week')
local_linear_trend = tfp.sts.LocalLinearTrend(
    observed_time_series=observed_time_series,
    name='local_linear_trend')
model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
                    observed_time_series=observed_time_series)
To fit the model to data statelessly, we construct two functions, init_fn
and build_surrogate_fn. init_fn takes a seed and returns an initial set of
variational parameters. build_surrogate_fn maps a set of parameters to a
surrogate posterior distribution, and is passed to
tfp.vi.fit_surrogate_posterior_stateless to optimize a variational bound.
# This example only works in the JAX backend because it uses
# `optax` for stateless optimizers.
# Assumed imports (not shown in the original example):
import jax
import optax
from tensorflow_probability.substrates import jax as tfp

# Derive independent seeds for initialization, fitting, and sampling.
seed = tfp.random.sanitize_seed(jax.random.PRNGKey(0), salt='fit_stateless')
init_seed, fit_seed, sample_seed = tfp.random.split_seed(seed, n=3)

init_fn, build_surrogate_fn = (
    tfp.sts.build_factored_surrogate_posterior_stateless(model=model))
initial_parameters = init_fn(init_seed)

# Joint distribution over model parameters and the observed series.
jd = model.joint_distribution(observed_time_series)
final_parameters, loss_curve = tfp.vi.fit_surrogate_posterior_stateless(
    target_log_prob_fn=jd.log_prob,
    initial_parameters=initial_parameters,
    build_surrogate_posterior_fn=build_surrogate_fn,
    optimizer=optax.adam(1e-4),
    num_steps=200,
    seed=fit_seed)

# Rebuild the surrogate from the fitted parameters and draw samples.
surrogate_posterior = build_surrogate_fn(final_parameters)
posterior_samples = surrogate_posterior.sample(50, seed=sample_seed)
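The stateless pattern used above (an init function, a build function, and a fit routine that threads parameters through explicitly) can be illustrated with a toy in plain Python. This sketch makes several simplifying assumptions: the target is a hypothetical Normal(3, 1), the loss is the closed-form KL divergence between two Normals instead of a stochastic variational bound, and the optimizer is plain gradient descent instead of an optax optimizer. The function fit_stateless is illustrative only and is not part of any library.

```python
import math
import random

TARGET_MEAN, TARGET_STD = 3.0, 1.0  # hypothetical target posterior Normal(3, 1)


def init_fn(seed):
    """Return initial variational parameters; nothing is stored globally."""
    rng = random.Random(seed)
    return {'loc': rng.gauss(0.0, 0.1), 'log_scale': -1.0}


def build_surrogate_fn(params):
    """Map a parameter set to a description of the surrogate Normal."""
    return {'loc': params['loc'], 'scale': math.exp(params['log_scale'])}


def kl_to_target(params):
    """Closed-form KL(surrogate || target) for two univariate Normals."""
    q = build_surrogate_fn(params)
    return (math.log(TARGET_STD / q['scale'])
            + (q['scale'] ** 2 + (q['loc'] - TARGET_MEAN) ** 2)
            / (2.0 * TARGET_STD ** 2)
            - 0.5)


def fit_stateless(initial_parameters, num_steps, learning_rate):
    """Gradient descent that returns new parameters instead of mutating."""
    params = dict(initial_parameters)
    loss_curve = []
    for _ in range(num_steps):
        q = build_surrogate_fn(params)
        loss_curve.append(kl_to_target(params))
        # Analytic gradients of the KL with respect to loc and log_scale.
        grad_loc = (q['loc'] - TARGET_MEAN) / TARGET_STD ** 2
        grad_log_scale = -1.0 + q['scale'] ** 2 / TARGET_STD ** 2
        params = {
            'loc': params['loc'] - learning_rate * grad_loc,
            'log_scale': params['log_scale'] - learning_rate * grad_log_scale,
        }
    return params, loss_curve


initial_parameters = init_fn(seed=0)
final_parameters, loss_curve = fit_stateless(
    initial_parameters, num_steps=200, learning_rate=0.1)
surrogate = build_surrogate_fn(final_parameters)
```

Because the parameters are passed in and returned explicitly, the same initial_parameters can be reused for several fits, and the surrogate can be rebuilt at any time from a saved parameter set.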