tfp.experimental.mcmc.snaper_criterion
The SNAPER criterion from [1].
tfp.experimental.mcmc.snaper_criterion(
    previous_state,
    proposed_state,
    accept_prob,
    trajectory_length,
    direction,
    state_mean=None,
    state_mean_weight=0.0,
    validate_args=False,
    experimental_shard_axis_names=None,
    experimental_reduce_chain_axis_names=None
)
SNAPER stands for Squared Norm Along Principal component ESJD Rate:
    SNAPER = E[(((x' - E[x'])^T p)**2 - ((x - E[x])^T p)**2)**2 /
               trajectory_length],
where x is the previous chain state, x' is the next chain state, and p is a unit
vector (the direction argument). All expectations are with respect to the
chain's stationary distribution. In practice, the inner expectations E[x] and
E[x'] are replaced by empirical means across chains, so computing this
criterion requires that at least 2 chains are present unless state_mean and
state_mean_weight are set. The outer expectation is computed by the caller
(e.g. in the GradientBasedTrajectoryLengthAdaptation kernel).
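The formula above can be sketched in NumPy for the simplest case: a single, non-nested state of shape [num_chains, dim]. The helper snaper_terms is hypothetical and is not the TFP implementation. It assumes plain unweighted empirical means across chains for both E[x] and E[x'], ignores accept_prob, sharding and state_mean, and returns one value per chain so that the caller can take the outer expectation, mirroring the division of labor described above.

```python
import numpy as np


def snaper_terms(previous_state, proposed_state, direction, trajectory_length):
    """Per-chain SNAPER terms for a single, non-nested state (hypothetical).

    Simplified sketch: states have shape [num_chains, dim] and `direction`
    has shape [dim] with unit norm. E[x] and E[x'] are replaced by plain
    empirical means across chains (accept_prob and state_mean are ignored),
    and the outer expectation is left to the caller.
    """
    previous_state = np.asarray(previous_state, dtype=float)
    proposed_state = np.asarray(proposed_state, dtype=float)
    direction = np.asarray(direction, dtype=float)
    if previous_state.shape[0] < 2:
        raise ValueError("Need at least 2 chains to estimate the state mean.")
    # Center each state by its own empirical mean across chains.
    centered_prev = previous_state - previous_state.mean(axis=0)
    centered_prop = proposed_state - proposed_state.mean(axis=0)
    # Project the centered states onto the unit direction p.
    prev_proj = centered_prev @ direction
    prop_proj = centered_prop @ direction
    # Squared change in the squared projection, divided by trajectory length.
    return (prop_proj**2 - prev_proj**2) ** 2 / trajectory_length
```

For example, two chains at +1 and -1 along p that move to +3 and -3 each contribute (9 - 1)**2 / 2 = 32 when trajectory_length is 2; the caller would then average these per-chain values.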
This can be thought of as the standard expected squared jump distance (ESJD)
criterion, except that the jump distance is computed in the space of squared
projections onto a vector.
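To make the contrast concrete, the small NumPy sketch below compares the two jump measures for a single chain, with the mean assumed known and equal to zero. The function names are hypothetical. A move that only flips the sign of the projection, from -2p to +2p, has a large ordinary squared jump distance but contributes nothing in the space of squared projections, whereas a move from the mean out to 2p contributes a lot.

```python
import numpy as np


def esjd_jump(x, x_next):
    """Ordinary squared jump distance ||x' - x||**2 for one chain."""
    return float(np.sum((np.asarray(x_next, dtype=float) - np.asarray(x, dtype=float)) ** 2))


def squared_projection_jump(x, x_next, direction, mean):
    """Squared jump between squared projections onto `direction`.

    Both states are centered with the same, assumed-known `mean` to keep
    the illustration minimal.
    """
    p = np.asarray(direction, dtype=float)
    before = float((np.asarray(x, dtype=float) - mean) @ p) ** 2
    after = float((np.asarray(x_next, dtype=float) - mean) @ p) ** 2
    return (after - before) ** 2


origin = np.zeros(2)
p = [1.0, 0.0]
print(esjd_jump([-2.0, 0.0], [2.0, 0.0]))                            # sign flip: 16.0
print(squared_projection_jump([-2.0, 0.0], [2.0, 0.0], p, origin))   # sign flip: 0.0
print(squared_projection_jump([0.0, 0.0], [2.0, 0.0], p, origin))    # outward move: 16.0
```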
The direction vector is typically chosen to be an approximation to the first
principal component of the state covariance matrix.
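For illustration, one way to obtain such a direction from a batch of samples is an exact eigendecomposition of their sample covariance. The helper leading_principal_direction is hypothetical; an adaptive sampler would more likely maintain a cheaper running approximation, and this sketch does not reproduce any particular library's scheme. It assumes samples of shape [num_samples, dim] with dim of at least 2.

```python
import numpy as np


def leading_principal_direction(samples):
    """Unit vector along the top principal component of `samples` (hypothetical).

    `samples` has shape [num_samples, dim] with dim >= 2. This uses an exact
    eigendecomposition for illustration only.
    """
    samples = np.asarray(samples, dtype=float)
    cov = np.cov(samples, rowvar=False)
    # eigh handles symmetric matrices and returns eigenvalues in ascending order.
    _, eigvecs = np.linalg.eigh(cov)
    p = eigvecs[:, -1]
    return p / np.linalg.norm(p)
```

The sign of the returned vector is arbitrary, which is harmless here: the criterion only uses squared projections, and (x^T (-p))**2 = (x^T p)**2.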
state_mean and state_mean_weight can be used to supplement the empirical means
as follows:

    E[x] ≈ (1 - state_mean_weight) * x.mean() + state_mean_weight * state_mean.
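The blending rule can be written as a small hypothetical helper (blended_state_mean, not a TFP function) for a state of shape [num_chains, dim], taking x.mean() to be the mean across chains. It also shows why a single chain needs state_mean: with one chain the empirical mean equals the state itself, so the centered state, and hence its projection onto direction, is identically zero unless state_mean_weight is positive.

```python
import numpy as np


def blended_state_mean(x, state_mean=None, state_mean_weight=0.0):
    """(1 - w) * x.mean(axis=0) + w * state_mean (hypothetical helper).

    `x` has shape [num_chains, dim]. With state_mean=None only the
    empirical mean across chains is used.
    """
    x = np.asarray(x, dtype=float)
    empirical = x.mean(axis=0)
    if state_mean is None:
        return empirical
    w = float(state_mean_weight)
    return (1.0 - w) * empirical + w * np.asarray(state_mean, dtype=float)
```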
Args:
  previous_state: (Possibly nested) floating point Tensor. The previous state
    of the HMC chain.
  proposed_state: (Possibly nested) floating point Tensor. The proposed state
    of the HMC chain.
  accept_prob: Floating Tensor. Probability of accepting the proposed state.
  trajectory_length: Floating Tensor. Mean trajectory length; it appears as
    the denominator in the formula above.
  direction: (Possibly nested) floating point Tensor. A unit vector onto which
    the centered state is projected before computing ESJD. This is typically
    chosen to be an approximation to the first principal component of the
    state covariance matrix.
  state_mean: Optional, possibly nested floating point Tensor. The estimated
    state mean.
  state_mean_weight: Floating point Tensor. The weight given to state_mean
    when it is blended with the empirical mean across chains.
  validate_args: Whether to perform non-static argument validation.
  experimental_shard_axis_names: A structure of string names indicating how
    members of the state are sharded.
  experimental_reduce_chain_axis_names: A string or list of string names
    indicating which named chain axes to reduce over when computing the
    criterion.
Returns:
  snaper: The value of the SNAPER criterion.
References
[1]: Sountsov, P. & Hoffman, M. (2021). Focusing on Difficult Directions for
Learning HMC Trajectory Lengths. <https://arxiv.org/abs/2110.11576>
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-11-21 UTC.