tfp.sts.build_factored_surrogate_posterior_stateless
Returns stateless functions for building a variational posterior.
tfp.sts.build_factored_surrogate_posterior_stateless(
model, batch_shape=(), name=None
)
The surrogate posterior consists of independent Normal distributions for each parameter with trainable `loc` and `scale`, transformed using the parameter's `bijector` to the appropriate support space for that parameter.
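To make this structure concrete, here is a minimal NumPy sketch of such a factored surrogate. It assumes every parameter is positive-valued with a softplus bijector. The helper names and the parameter names `noise_scale` and `trend_scale` are hypothetical, and this is not the library's internal code. Each factor draws from `Normal(loc, scale)` in unconstrained space and pushes the draw through the bijector, so every sample lands in the parameter's support.

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)); maps the real line onto (0, inf).
    return np.logaddexp(0.0, x)

def sample_factor(loc, scale, num_samples, rng):
    """Draws from one factor: Normal(loc, scale) pushed through softplus.

    Hypothetical helper: `loc` and `scale` stand in for the trainable
    variational parameters of a single positive-valued model parameter.
    """
    unconstrained = loc + scale * rng.standard_normal(num_samples)
    return softplus(unconstrained)

def sample_factored_surrogate(params, num_samples, rng):
    # Each parameter has its own independent factor (mean-field), so a joint
    # draw is a dict of independent per-parameter draws.
    return {
        name: sample_factor(p["loc"], p["scale"], num_samples, rng)
        for name, p in params.items()
    }

rng = np.random.default_rng(0)
# Hypothetical parameter names, both assumed to live on (0, inf).
params = {
    "noise_scale": {"loc": 0.0, "scale": 1.0},
    "trend_scale": {"loc": -2.0, "scale": 0.5},
}
draws = sample_factored_surrogate(params, num_samples=1000, rng=rng)
```

Because softplus is monotonic, quantiles map through it directly. For example, the median of the `noise_scale` factor is softplus(0) = log 2 ≈ 0.693.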
| Args | |
|---|---|
| `model` | An instance of `StructuralTimeSeries` representing a time-series model. This represents a joint distribution over time-series and their parameters with batch shape `[b1, ..., bN]`. |
| `batch_shape` | Batch shape (Python `tuple`, `list`, or `int`) of initial states to optimize in parallel. Default value: `()` (i.e., just run a single optimization). |
| `name` | Python `str` name prefixed to ops created by this function. Default value: `None` (i.e., 'build_factored_surrogate_posterior'). |
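One way to picture `batch_shape` is as extra leading dimensions on the variational parameters, so that several independent optimizations advance together in one vectorized step. The sketch below assumes each unconstrained `loc` gets shape `batch_shape + param_shape`, and draws initial values uniformly from [-2, 2] purely for illustration. `as_shape` and `init_locs` are hypothetical helpers, not TFP functions; they also show how an `int`, `list`, or `tuple` can be normalized into one shape tuple.

```python
import numpy as np

def as_shape(batch_shape):
    # Accept the same forms the docs list: a Python int, list, or tuple.
    if isinstance(batch_shape, int):
        return (batch_shape,)
    return tuple(batch_shape)

def init_locs(param_shapes, batch_shape, rng):
    """Hypothetical initializer: one unconstrained `loc` array per parameter.

    Assumes each loc has shape batch_shape + param_shape, so indexing the
    leading batch dimensions selects one of the parallel optimizations.
    """
    prefix = as_shape(batch_shape)
    return {
        name: rng.uniform(-2.0, 2.0, size=prefix + tuple(shape))
        for name, shape in param_shapes.items()
    }

rng = np.random.default_rng(0)
# Hypothetical parameters: a scalar and a length-7 vector.
param_shapes = {"noise_scale": (), "seasonal_effects": (7,)}
locs = init_locs(param_shapes, batch_shape=4, rng=rng)
```

Here `locs["noise_scale"]` has shape `(4,)` and `locs["seasonal_effects"]` has shape `(4, 7)`: four starting points, each optimized independently.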
| Returns | |
|---|---|
| `init_fn` | A function that takes a stateless random seed and returns the initial parameters of the variational posterior. |
| `build_surrogate_posterior_fn` | A function that takes the parameters and returns a surrogate posterior distribution. |
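The two returned functions follow a stateless, functional pattern. All trainable state lives in an explicit parameter structure, and randomness comes only from seeds the caller passes in. Below is a minimal NumPy sketch of that contract for independent Normal factors. `make_stateless_surrogate`, `FactoredNormal`, and the `{"loc", "log_scale"}` layout are hypothetical stand-ins chosen for illustration, not TFP's implementation.

```python
import numpy as np

class FactoredNormal:
    """Hypothetical surrogate: independent Normals with given loc and scale."""

    def __init__(self, loc, scale):
        self.loc = np.asarray(loc, dtype=float)
        self.scale = np.asarray(scale, dtype=float)

    def sample(self, n, seed):
        # Randomness comes only from the seed argument, never hidden state.
        rng = np.random.default_rng(seed)
        eps = rng.standard_normal((n,) + self.loc.shape)
        return self.loc + self.scale * eps

def make_stateless_surrogate(num_params):
    def init_fn(seed):
        # Initial variational parameters as a plain dict. Storing log(scale)
        # keeps the scale positive under unconstrained updates (an assumption).
        rng = np.random.default_rng(seed)
        return {"loc": rng.uniform(-2.0, 2.0, size=num_params),
                "log_scale": np.full(num_params, np.log(0.1))}

    def build_surrogate_posterior_fn(params):
        # Pure function of the parameters: same params, same distribution.
        return FactoredNormal(params["loc"], np.exp(params["log_scale"]))

    return init_fn, build_surrogate_posterior_fn

init_fn, build_fn = make_stateless_surrogate(num_params=3)
params = init_fn(seed=0)
samples = build_fn(params).sample(5, seed=1)
```

Because nothing is mutated, an optimizer only needs to map one parameter dict to the next, and any step can be replayed from its seed.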
Examples
Assume we've built a structural time-series model:
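from tensorflow_probability.substrates import jax as tfp

# `observed_time_series` is assumed to be an existing array of observed
# values, with time along the last axis.
day_of_week = tfp.sts.Seasonal(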
num_seasons=7,
observed_time_series=observed_time_series,
name='day_of_week')
local_linear_trend = tfp.sts.LocalLinearTrend(
observed_time_series=observed_time_series,
name='local_linear_trend')
model = tfp.sts.Sum(components=[day_of_week, local_linear_trend],
observed_time_series=observed_time_series)
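The `Sum` model treats the series as the sum of its components plus observation noise. As a hedged illustration (synthetic values, not data from this page), the NumPy snippet below simulates a series with that structure: a day-of-week effect with 7 seasons and a local linear trend whose level and slope follow Gaussian random walks. The seasonal effects are held fixed for simplicity, and all noise scales are arbitrary assumptions. Such an array could stand in for `observed_time_series` when experimenting.

```python
import numpy as np

def simulate_series(num_timesteps, seed=0):
    rng = np.random.default_rng(seed)
    # Day-of-week effects, centered to sum to zero and held fixed over time.
    effects = rng.normal(0.0, 1.0, size=7)
    effects -= effects.mean()
    seasonal = effects[np.arange(num_timesteps) % 7]

    # Local linear trend: level and slope each evolve as a Gaussian random walk.
    level, slope = 10.0, 0.1
    trend = np.empty(num_timesteps)
    for t in range(num_timesteps):
        trend[t] = level
        level = level + slope + rng.normal(0.0, 0.1)
        slope = slope + rng.normal(0.0, 0.01)

    # Observation noise on top of the summed components.
    noise = rng.normal(0.0, 0.5, size=num_timesteps)
    return (seasonal + trend + noise).astype(np.float32)

observed_time_series = simulate_series(num_timesteps=8 * 7)  # eight weeks
```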
To (statelessly) fit the model to data, we construct `init_fn` and `build_surrogate_fn`. `init_fn` constructs an initial set of parameters, and `build_surrogate_fn` is passed into `tfp.vi.fit_surrogate_posterior_stateless` to optimize a variational bound.
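The variational bound here is the evidence lower bound (ELBO), E_q[log p(z)] - E_q[log q(z)]. To show the mechanics without TFP or optax, here is a self-contained NumPy sketch that fits a one-dimensional Normal surrogate to a Normal(3, 2) target by stochastic gradient ascent on a reparameterized ELBO estimate. It swaps in plain gradient steps with hand-derived gradients for `optax.adam`, and `fit_stateless` is a hypothetical helper. Like the stateless API, it takes initial parameters and a seed and returns the final parameters plus a loss curve.

```python
import numpy as np

TARGET_MU, TARGET_SIGMA = 3.0, 2.0  # hypothetical target Normal(3, 2)

def target_log_prob(z):
    return (-0.5 * ((z - TARGET_MU) / TARGET_SIGMA) ** 2
            - np.log(TARGET_SIGMA) - 0.5 * np.log(2 * np.pi))

def fit_stateless(initial_parameters, num_steps, learning_rate, seed,
                  sample_size=100):
    """Maximizes a Monte Carlo ELBO for q = Normal(loc, exp(log_scale))."""
    rng = np.random.default_rng(seed)
    loc, log_scale = initial_parameters
    losses = []
    for _ in range(num_steps):
        scale = np.exp(log_scale)
        eps = rng.standard_normal(sample_size)
        z = loc + scale * eps  # reparameterization trick
        # ELBO = E_q[log p(z)] + entropy(q); the Normal entropy is closed form.
        entropy = log_scale + 0.5 * np.log(2 * np.pi * np.e)
        losses.append(-(target_log_prob(z).mean() + entropy))
        # Chain rule through z: dz/dloc = 1, dz/dlog_scale = scale * eps.
        dlogp_dz = -(z - TARGET_MU) / TARGET_SIGMA ** 2
        grad_loc = dlogp_dz.mean()
        grad_log_scale = (dlogp_dz * eps * scale).mean() + 1.0  # +1: entropy
        loc = loc + learning_rate * grad_loc
        log_scale = log_scale + learning_rate * grad_log_scale
    return (loc, log_scale), np.array(losses)

(final_loc, final_log_scale), loss_curve = fit_stateless(
    initial_parameters=(0.0, 0.0), num_steps=2000, learning_rate=0.05, seed=0)
```

Since this target is normalized, the negative ELBO equals KL(q || p), so the loss curve should start near 1.44 and settle close to 0 once the surrogate matches the target.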
# This example only works in the JAX backend because it uses
# `optax` for stateless optimizers.
import jax
import optax
from tensorflow_probability.substrates import jax as tfp
seed = tfp.random.sanitize_seed(jax.random.PRNGKey(0), salt='fit_stateless')
init_seed, fit_seed, sample_seed = tfp.random.split_seed(seed, n=3)
init_fn, build_surrogate_fn = (
tfp.sts.build_factored_surrogate_posterior_stateless(model=model))
initial_parameters = init_fn(init_seed)
jd = model.joint_distribution(observed_time_series)
final_parameters, loss_curve = tfp.vi.fit_surrogate_posterior_stateless(
target_log_prob_fn=jd.log_prob,
initial_parameters=initial_parameters,
build_surrogate_posterior_fn=build_surrogate_fn,
optimizer=optax.adam(1e-4),
num_steps=200,
seed=fit_seed)
surrogate_posterior = build_surrogate_fn(final_parameters)
posterior_samples = surrogate_posterior.sample(50, seed=sample_seed)
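The example derives all randomness from one root seed, split into separate seeds for initialization, fitting, and sampling, so rerunning it reproduces the same result. NumPy has an analogous mechanism in `np.random.SeedSequence.spawn`. The sketch below uses it in place of `tfp.random.split_seed` (a swapped-in technique, not TFP's seed algorithm) to show that the same root seed always yields the same streams; `split_rngs` and `run` are hypothetical helpers.

```python
import numpy as np

def split_rngs(root_seed, n):
    # Derive n independent child streams from one root seed.
    children = np.random.SeedSequence(root_seed).spawn(n)
    return [np.random.default_rng(child) for child in children]

def run(root_seed):
    # Mirrors the init / fit / sample split in the example above.
    init_rng, fit_rng, sample_rng = split_rngs(root_seed, n=3)
    initial = init_rng.uniform(-2.0, 2.0, size=2)
    fit_noise = fit_rng.standard_normal(4)
    draws = sample_rng.standard_normal(3)
    return initial, fit_noise, draws

first = run(0)
again = run(0)
other = run(1)
```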
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-11-21 UTC.