tfp.mcmc.potential_scale_reduction
Gelman and Rubin (1992)'s potential scale reduction for chain convergence.
tfp.mcmc.potential_scale_reduction(
    chains_states,
    independent_chain_ndims=1,
    split_chains=False,
    validate_args=False,
    name=None
)
Used in the notebooks
- Multilevel Modeling Primer in TensorFlow Probability
  (https://www.tensorflow.org/probability/examples/Multilevel_Modeling_Primer)
- TensorFlow Probability Case Study: Covariance Estimation
  (https://www.tensorflow.org/probability/examples/TensorFlow_Probability_Case_Study_Covariance_Estimation)
- A Tour of TensorFlow Probability
  (https://www.tensorflow.org/probability/examples/A_Tour_of_TensorFlow_Probability)
Given N > 1 states from each of C > 1 independent chains, the potential
scale reduction factor, commonly referred to as R-hat, measures convergence
of the chains (to the same target) by testing for equality of means.
Specifically, R-hat measures the degree to which variance (of the means)
between chains exceeds what one would expect if the chains were identically
distributed. See [Gelman and Rubin (1992)][1]; [Brooks and Gelman (1998)][2].
Some guidelines:
- The initial state of the chains should be drawn from a distribution
  overdispersed with respect to the target.
- If all chains converge to the target, then as N --> infinity,
  R-hat --> 1. Before that, R-hat > 1 (except in pathological cases,
  e.g. if the chain paths were identical).
- The above holds for any number of chains C > 1. Increasing C does
  improve effectiveness of the diagnostic.
- Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but
  of course this is problem-dependent. See [Brooks and Gelman (1998)][2].
- R-hat only measures non-convergence of the mean. If higher moments, or
  other statistics are desired, a different diagnostic should be used. See
  [Brooks and Gelman (1998)][2].
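To make the between-chain vs. within-chain comparison concrete, here is a
minimal NumPy sketch of the classic (non-split) R-hat estimator. It is
illustrative only: the library's estimator includes additional
finite-sample correction terms, so its values may differ slightly.

import numpy as np

def rhat_sketch(states):
  """Classic R-hat for states of shape [N, C]: N steps, C chains."""
  n, c = states.shape
  chain_means = states.mean(axis=0)
  # W: mean of the within-chain sample variances.
  w = states.var(axis=0, ddof=1).mean()
  # B/N: sample variance of the chain means.
  b_div_n = chain_means.var(ddof=1)
  # Pooled variance estimate; inflated while the chain means still disagree.
  var_plus = (n - 1.) / n * w + b_div_n
  return var_plus / w

same = np.random.normal(size=(1000, 4))     # identically distributed chains
apart = same + np.array([0., 1., 2., 3.])   # chains stuck at different means
print(rhat_sketch(same))    # ==> close to 1
print(rhat_sketch(apart))   # ==> well above 1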
Args
  chains_states: Tensor or Python structure of Tensors representing the
    states of a Markov Chain at each result step. The ith state is assumed
    to have shape [Ni, Ci1, Ci2, ..., CiD] + A. Dimension 0 indexes the
    Ni > 1 result steps of the Markov Chain. Dimensions 1 through D index
    the Ci1 x ... x CiD independent chains to be tested for convergence to
    the same target. The remaining dimensions, A, can have any shape (even
    empty).
  independent_chain_ndims: Integer-type Tensor with value >= 1 giving the
    number of dimensions, from dim = 1 to dim = D, holding independent
    chain results to be tested for convergence.
  split_chains: Python bool. If True, divide samples from each chain into
    first and second halves, treating these as separate chains. This makes
    R-hat more robust to non-stationary chains, and is recommended in [3].
  validate_args: Whether to add runtime checks of argument validity. If
    False and arguments are incorrect, correct behavior is not guaranteed.
  name: String name to prepend to created TF ops. Default:
    potential_scale_reduction.
Returns
  Tensor structure parallel to chains_states representing the R-hat
  statistic for the state(s). Same dtype as the corresponding state, and
  shape equal to state.shape[1 + independent_chain_ndims:].
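For instance, a quick shape sketch (the random states below are stand-ins
for real chain output, chosen only to illustrate the shape contract):

import tensorflow as tf
import tensorflow_probability as tfp

# 500 result steps, a 4 x 2 grid of independent chains, 3-dimensional state.
states = tf.random.normal([500, 4, 2, 3])
rhat = tfp.mcmc.potential_scale_reduction(states, independent_chain_ndims=2)
rhat.shape
==> (3,)

# A structure of states yields a parallel structure of R-hat statistics.
structure = {'loc': tf.random.normal([500, 4]),
             'scale': tf.random.normal([500, 4, 2])}
rhats = tfp.mcmc.potential_scale_reduction(structure)
# rhats['loc'].shape ==> (), rhats['scale'].shape ==> (2,)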
Raises
  ValueError: If independent_chain_ndims < 1.
Examples
Diagnosing convergence by monitoring 10 chains that each attempt to
sample from a 2-variate normal.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.])

# Get 10 (2x) overdispersed initial states.
initial_state = target.sample(10) * 2.
initial_state.shape
==> (10, 2)

# Get 1000 samples from the 10 independent chains.
chains_states = tfp.mcmc.sample_chain(
    num_burnin_steps=200,
    num_results=1000,
    current_state=initial_state,
    trace_fn=None,
    kernel=tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=target.log_prob,
        step_size=0.05,
        num_leapfrog_steps=20))
chains_states.shape
==> (1000, 10, 2)

rhat = tfp.mcmc.potential_scale_reduction(
    chains_states, independent_chain_ndims=1)

# The second dimension needed a longer burn-in.
rhat
==> [1.05, 1.3]
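Continuing the example, split_chains=True applies the more robust split
R-hat recommended in [3]: each of the 10 chains is divided into halves and
treated as 20 chains. (This is a sketch; [3] suggests requiring split
R-hat < 1.01 before trusting the samples.)

rhat_split = tfp.mcmc.potential_scale_reduction(
    chains_states, independent_chain_ndims=1, split_chains=True)
rhat_split.shape
==> (2,)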
To see why R-hat is reasonable, let X be a random variable drawn uniformly
from the combined states (combined over all chains). Then, in the limit
N, C --> infinity, with E, Var denoting expectation and variance,

R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].

By the law of total variance, the numerator is the variance of the combined
states, and the denominator is the total variance minus the variance of the
individual chain means. If the chains are all drawing from the same
distribution, they will have the same mean, and thus the ratio should be one.
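A quick numeric check of this identity (a NumPy sketch; all names here are
ours): with equal-length chains, the combined variance decomposes exactly
into the within- and between-chain pieces.

import numpy as np

states = np.random.normal(size=(10000, 50))   # N steps from each of C chains
within = states.var(axis=0).mean()            # E[Var[X | chain]]
between = states.mean(axis=0).var()           # Var[E[X | chain]]
total = states.reshape(-1).var()              # Var[X], combined states

print(np.allclose(total, within + between))   # ==> True (law of total variance)
print((within + between) / within)            # the R-hat ratio; close to 1 here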
References

[1]: Andrew Gelman and Donald B. Rubin. Inference from Iterative Simulation
     Using Multiple Sequences. Statistical Science, 7(4):457-472, 1992.

[2]: Stephen P. Brooks and Andrew Gelman. General Methods for Monitoring
     Convergence of Iterative Simulations. Journal of Computational and
     Graphical Statistics, 7(4), 1998.

[3]: Aki Vehtari, Andrew Gelman, Daniel Simpson, Bob Carpenter, and
     Paul-Christian Bürkner. Rank-normalization, folding, and localization:
     An improved R-hat for assessing convergence of MCMC. Bayesian Analysis,
     16(2):667-718, 2021.