Runs one step of the RWM algorithm with symmetric proposal.
Inherits From: TransitionKernel
```python
tfp.substrates.jax.mcmc.RandomWalkMetropolis(
    target_log_prob_fn,
    new_state_fn=None,
    experimental_shard_axis_names=None,
    name=None
)
```
Random Walk Metropolis is a gradient-free Markov chain Monte Carlo
(MCMC) algorithm. Each step consists of a proposal step,
`proposal_state = current_state + perturb`, where `perturb` is a random
perturbation, followed by a Metropolis-Hastings accept/reject step. For more
details, see Section 2.1 of Roberts and Rosenthal (2004).
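
To make the propose/accept cycle concrete, here is a minimal pure-Python sketch of a single RWM step for a scalar state with a Normal(0, scale) perturbation. It is not TFP's implementation: `rwm_step` and `standard_normal_log_prob` are hypothetical names, and the sketch uses the standard library's `random.Random` instead of TFP's stateless seeds. Because the proposal is symmetric, the acceptance test only compares target log-densities.

```python
import math
import random


def rwm_step(x, log_prob, scale, rng):
    """One Random Walk Metropolis step for a scalar state (illustrative sketch).

    Proposes x' = x + perturb with perturb ~ Normal(0, scale), then accepts
    x' with probability min(1, p(x') / p(x)). The normal proposal is
    symmetric, so no Hastings correction term appears in the ratio.
    """
    proposal = x + rng.gauss(0.0, scale)
    log_accept_ratio = log_prob(proposal) - log_prob(x)
    # u = 1 - random() lies in (0, 1], so math.log(u) is always defined.
    u = 1.0 - rng.random()
    if math.log(u) < log_accept_ratio:
        return proposal, True  # accept: move to the proposal
    return x, False  # reject: stay at the current state


def standard_normal_log_prob(x):
    # Unnormalized N(0, 1) log-density; the constant cancels in the ratio.
    return -0.5 * x * x


rng = random.Random(42)
x = 1.0
samples = []
for step in range(20000):
    x, _ = rwm_step(x, standard_normal_log_prob, scale=1.0, rng=rng)
    if step >= 2000:  # discard the first 2000 steps as burn-in
        samples.append(x)

mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print('Estimated mean:', mean)
print('Estimated standard deviation:', std)
```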
This class implements RWM with normal and uniform proposals. Alternatively,
the user can supply any custom proposal-generating function, provided the
proposal distribution is symmetric.
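
In TFP the built-in proposals are `tfp.mcmc.random_walk_normal_fn` and `tfp.mcmc.random_walk_uniform_fn`. The sketch below only mimics their shape in pure Python: each factory returns a function that takes a list of state parts and perturbs every part independently. The factory names, the `[-scale, scale]` uniform window, and passing a `random.Random` instead of a seed are all assumptions for illustration. The Laplace variant shows what a custom symmetric proposal could look like.

```python
import random


# Hypothetical pure-Python analogues of proposal functions. Unlike TFP's
# new_state_fn, which receives a stateless seed, these take a random.Random.
def make_normal_new_state_fn(scale=1.0):
    def _fn(state_parts, rng):
        # Add an independent Normal(0, scale) perturbation to every part.
        return [part + rng.gauss(0.0, scale) for part in state_parts]
    return _fn


def make_uniform_new_state_fn(scale=1.0):
    def _fn(state_parts, rng):
        # Add an independent Uniform(-scale, scale) perturbation to every part.
        return [part + rng.uniform(-scale, scale) for part in state_parts]
    return _fn


def make_laplace_new_state_fn(scale=1.0):
    # A custom symmetric proposal: the difference of two i.i.d.
    # Exponential(rate=1/scale) draws is Laplace(0, scale).
    def _fn(state_parts, rng):
        return [part + rng.expovariate(1.0 / scale) - rng.expovariate(1.0 / scale)
                for part in state_parts]
    return _fn


rng = random.Random(0)
state_parts = [0.0, 10.0]
normal_next = make_normal_new_state_fn(0.5)(state_parts, rng)
uniform_next = make_uniform_new_state_fn(0.5)(state_parts, rng)
laplace_next = make_laplace_new_state_fn(0.5)(state_parts, rng)
print(normal_next, uniform_next, laplace_next)
```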
The function `one_step` can update multiple chains in parallel. It assumes
that all leftmost dimensions of `current_state` index independent chain states
(and are therefore updated independently). The output of
`target_log_prob_fn(*current_state)` should sum log-probabilities across all
event dimensions. Each chain, that is, each slice that fixes the leftmost
(chain) indices and spans the rightmost (event) dimensions, may have a
different target distribution; for example, `current_state[0, :]` could have a
different target distribution from `current_state[1, :]`. These semantics
are governed by `target_log_prob_fn(*current_state)`. (The number of
independent chains is `tf.size(target_log_prob_fn(*current_state))`.)
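
The batching rule can be illustrated without TFP. In this pure-Python sketch (all names hypothetical), the state is a list indexed by chain, the log-density function returns one value per chain, and every chain makes its own accept/reject decision. Chain `i` targets a Normal with its own location, mirroring how `current_state[0, :]` and `current_state[1, :]` may have different targets.

```python
import math
import random

# Hypothetical per-chain targets: chain i targets Normal(LOCS[i], 1).
LOCS = [0.0, 5.0, -3.0]


def batched_log_prob(states):
    # One (unnormalized) log-density per chain, so the number of
    # independent chains equals len(batched_log_prob(states)).
    return [-0.5 * (x - loc) ** 2 for x, loc in zip(states, LOCS)]


def batched_rwm_step(states, log_prob_fn, scale, rng):
    proposals = [x + rng.gauss(0.0, scale) for x in states]
    current_lp = log_prob_fn(states)
    proposed_lp = log_prob_fn(proposals)
    next_states = []
    for x, x_new, lp, lp_new in zip(states, proposals, current_lp, proposed_lp):
        # Each chain accepts or rejects independently of the others.
        accept = math.log(1.0 - rng.random()) < lp_new - lp
        next_states.append(x_new if accept else x)
    return next_states


rng = random.Random(0)
states = [1.0, 1.0, 1.0]
sums = [0.0] * len(states)
num_kept = 0
for step in range(20000):
    states = batched_rwm_step(states, batched_log_prob, 1.0, rng)
    if step >= 2000:  # burn-in
        sums = [s + x for s, x in zip(sums, states)]
        num_kept += 1
chain_means = [s / num_kept for s in sums]
print('Per-chain means:', chain_means)
```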
Examples:
Sampling from the Standard Normal Distribution.
```python
import numpy as np
from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
dtype = np.float32
target = tfd.Normal(loc=dtype(0), scale=dtype(1))
samples = tfp.mcmc.sample_chain(
    num_results=1000,
    current_state=dtype(1),
    kernel=tfp.mcmc.RandomWalkMetropolis(target.log_prob),
    num_burnin_steps=500,
    trace_fn=None,
    seed=42)
sample_mean = tf.math.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.math.reduce_mean(
        tf.math.squared_difference(samples, sample_mean),
        axis=0))
print('Estimated mean: {}'.format(sample_mean))
print('Estimated standard deviation: {}'.format(sample_std))
```
Sampling from a 2-D Normal Distribution.
```python
import numpy as np
from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
dtype = np.float32
true_mean = dtype([0, 0])
true_cov = dtype([[1, 0.5],
                  [0.5, 1]])
num_results = 500
num_chains = 100
# Target distribution is defined through the Cholesky decomposition `L`:
L = tf.linalg.cholesky(true_cov)
target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=L)
# Initial state of the chains, shape [num_chains, 2]:
init_state = np.ones([num_chains, 2], dtype=dtype)
# Run Random Walk Metropolis with normal proposal for `num_results`
# iterations for `num_chains` independent chains:
samples = tfp.mcmc.sample_chain(
    num_results=num_results,
    current_state=init_state,
    kernel=tfp.mcmc.RandomWalkMetropolis(target_log_prob_fn=target.log_prob),
    num_burnin_steps=200,
    num_steps_between_results=1,  # Thinning.
    trace_fn=None,
    seed=54)
# `samples` has shape [num_results, num_chains, 2].
sample_mean = tf.math.reduce_mean(samples, axis=0)  # Per-chain mean: [num_chains, 2].
x = samples - sample_mean  # Centered samples: [num_results, num_chains, 2].
# Per-chain sample covariance: [num_chains, 2, 2].
sample_cov = tf.matmul(tf.transpose(x, [1, 2, 0]),
                       tf.transpose(x, [1, 0, 2])) / num_results
mean_sample_mean = tf.math.reduce_mean(sample_mean, axis=0)  # [2]
mean_sample_cov = tf.math.reduce_mean(sample_cov, axis=0)  # [2, 2]
# Covariance, across chains, of the flattened per-chain covariance estimates.
x = tf.reshape(sample_cov - mean_sample_cov, [num_chains, 2 * 2])
cov_sample_cov = tf.reshape(tf.matmul(x, x, transpose_a=True) / num_chains,
                            shape=[2 * 2, 2 * 2])
print('Estimated mean: {}'.format(mean_sample_mean))
print('Estimated avg covariance: {}'.format(mean_sample_cov))
print('Estimated covariance of covariance: {}'.format(cov_sample_cov))
```
Sampling from the Standard Normal Distribution using Cauchy proposal.
```python
import numpy as np
from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
import tensorflow_probability as tfp; tfp = tfp.substrates.jax
tfd = tfp.distributions
dtype = np.float32
num_burnin_steps = 500
num_chain_results = 1000

def cauchy_new_state_fn(scale, dtype):
  cauchy = tfd.Cauchy(loc=dtype(0), scale=dtype(scale))
  def _fn(state_parts, seed):
    # Perturb each state part with independent Cauchy(0, scale) noise,
    # using a separate seed per part.
    next_state_parts = []
    part_seeds = tfp.random.split_seed(
        seed, n=len(state_parts), salt='rwmcauchy')
    for sp, ps in zip(state_parts, part_seeds):
      next_state_parts.append(sp + cauchy.sample(
        sample_shape=sp.shape, seed=ps))
    return next_state_parts
  return _fn

target = tfd.Normal(loc=dtype(0), scale=dtype(1))
samples = tfp.mcmc.sample_chain(
    num_results=num_chain_results,
    num_burnin_steps=num_burnin_steps,
    current_state=dtype(1),
    kernel=tfp.mcmc.RandomWalkMetropolis(
        target.log_prob,
        new_state_fn=cauchy_new_state_fn(scale=0.5, dtype=dtype)),
    trace_fn=None,
    seed=42)
sample_mean = tf.math.reduce_mean(samples, axis=0)
sample_std = tf.sqrt(
    tf.math.reduce_mean(
        tf.math.squared_difference(samples, sample_mean),
        axis=0))
print('Estimated mean: {}'.format(sample_mean))
print('Estimated standard deviation: {}'.format(sample_std))
```
| Args | Description |
|---|---|
| `target_log_prob_fn` | Python callable which takes an argument like `current_state` (or `*current_state` if it's a list) and returns its (possibly unnormalized) log-density under the target distribution. |
| `new_state_fn` | Python callable which takes a list of state parts and a seed; returns a same-type `list` of `Tensor`s, each being a perturbation of the input state parts. The perturbation distribution is assumed to be a symmetric distribution centered at the input state part. Default value: `None`, which is mapped to `tfp.mcmc.random_walk_normal_fn()`. |
| `experimental_shard_axis_names` | A structure of string names indicating how members of the state are sharded. |
| `name` | Python `str` name prefixed to Ops created by this function. Default value: `None` (i.e., `'rwm_kernel'`). |

| Raises | Description |
|---|---|
| `ValueError` | If there isn't one `scale` or a list with the same length as `current_state`. |

| Attributes | Description |
|---|---|
| `experimental_shard_axis_names` | The shard axis names for members of the state. |
| `is_calibrated` | Returns `True` if the Markov chain converges to the specified distribution. `TransitionKernel`s which are "uncalibrated" are often calibrated by composing them with the `tfp.mcmc.MetropolisHastings` `TransitionKernel`. |
| `name` |  |
| `new_state_fn` |  |
| `parameters` | Returns a `dict` of `__init__` arguments and their values. |
| `target_log_prob_fn` |  |
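
The symmetry assumption on `new_state_fn` is what lets RWM skip the proposal densities. The general Metropolis-Hastings log acceptance ratio is log p(x') - log p(x) + log q(x | x') - log q(x' | x); when q(x' | x) = q(x | x'), the last two terms cancel. The pure-Python sketch below checks this numerically for a Cauchy random-walk proposal, as in the Cauchy example above, and shows that a hypothetical proposal whose scale depends on the current state does not cancel. All function names are illustrative and not part of TFP.

```python
import math


def cauchy_log_density(x, loc, scale):
    # log of the Cauchy(loc, scale) density evaluated at x.
    z = (x - loc) / scale
    return -math.log(math.pi * scale * (1.0 + z * z))


def mh_log_accept_ratio(log_target, x, x_new, log_q):
    # Full Metropolis-Hastings ratio, with log_q(to, frm) = log q(to | frm):
    #   log p(x') - log p(x) + log q(x | x') - log q(x' | x)
    return (log_target(x_new) - log_target(x)
            + log_q(x, x_new) - log_q(x_new, x))


def rwm_log_accept_ratio(log_target, x, x_new):
    # RWM drops the proposal terms; valid only if q(x' | x) == q(x | x').
    return log_target(x_new) - log_target(x)


def log_target(v):
    return -0.5 * v * v  # unnormalized standard normal


def symmetric_q(to, frm):
    # Cauchy random walk centered at the current state: depends on |to - frm|.
    return cauchy_log_density(to, loc=frm, scale=0.5)


def asymmetric_q(to, frm):
    # A state-dependent scale breaks symmetry: q(x' | x) != q(x | x').
    return cauchy_log_density(to, loc=frm, scale=0.5 + abs(frm))


x, x_new = 1.0, -0.3
full = mh_log_accept_ratio(log_target, x, x_new, symmetric_q)
full_asym = mh_log_accept_ratio(log_target, x, x_new, asymmetric_q)
simplified = rwm_log_accept_ratio(log_target, x, x_new)
print(full, simplified, full_asym)
```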
Methods
bootstrap_results
```python
bootstrap_results(
    init_state
)
```
Creates initial previous_kernel_results using a supplied state.
copy
```python
copy(
    **override_parameter_kwargs
)
```
Non-destructively creates a deep copy of the kernel.
| Args | Description |
|---|---|
| `**override_parameter_kwargs` | Python string/value dictionary of initialization arguments to override with new values. |

| Returns | Description |
|---|---|
| `new_kernel` | `TransitionKernel` object of the same type as `self`, initialized with the union of `self.parameters` and `override_parameter_kwargs`, with any shared keys overridden by the value of `override_parameter_kwargs`, i.e., `dict(self.parameters, **override_parameter_kwargs)`. |
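
The merge rule for `copy` is ordinary Python `dict` semantics: keyword overrides win for shared keys, and every other stored parameter is carried over. A minimal sketch with a hypothetical `ToyKernel` class (not a TFP class) that keeps its constructor arguments in `parameters`; it only illustrates the merge, not TFP's copying internals.

```python
class ToyKernel:
    """Hypothetical stand-in showing how copy() merges parameters."""

    def __init__(self, target_log_prob_fn, new_state_fn=None, name=None):
        self._parameters = dict(target_log_prob_fn=target_log_prob_fn,
                                new_state_fn=new_state_fn, name=name)

    @property
    def parameters(self):
        return self._parameters

    def copy(self, **override_parameter_kwargs):
        # Overrides win for shared keys; all other parameters are kept.
        parameters = dict(self.parameters, **override_parameter_kwargs)
        return type(self)(**parameters)


def log_prob(x):
    return -0.5 * x * x


original = ToyKernel(log_prob, name='rwm_kernel')
renamed = original.copy(name='my_rwm')
print(original.parameters['name'], renamed.parameters['name'])
```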
experimental_with_shard_axes
```python
experimental_with_shard_axes(
    shard_axes
)
```
Returns a copy of the kernel with the provided shard axis names.
| Args | Description |
|---|---|
| `shard_axes` | A structure of strings indicating the shard axis names for each component of this kernel's state. |

| Returns |
|---|
| A copy of the current kernel with the shard axis information. |
one_step
```python
one_step(
    current_state, previous_kernel_results, seed=None
)
```
Runs one iteration of Random Walk Metropolis using the kernel's `new_state_fn` proposal (a normal proposal by default).
| Args | Description |
|---|---|
| `current_state` | `Tensor` or Python `list` of `Tensor`s representing the current state(s) of the Markov chain(s). The first `r` dimensions index independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. |
| `previous_kernel_results` | `collections.namedtuple` containing `Tensor`s representing values from previous calls to this function (or from the `bootstrap_results` function). |
| `seed` | PRNG seed; see `tfp.random.sanitize_seed` for details. |

| Returns | Description |
|---|---|
| `next_state` | `Tensor` or Python `list` of `Tensor`s representing the state(s) of the Markov chain(s) after taking exactly one step. Has the same type and shape as `current_state`. |
| `kernel_results` | `collections.namedtuple` of internal calculations used to advance the chain. |

| Raises | Description |
|---|---|
| `ValueError` | If there isn't one `scale` or a list with the same length as `current_state`. |
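
`bootstrap_results` and `one_step` are meant to be used together: the first builds the initial kernel results from a state, and each call to the second consumes the previous results and returns the next state plus new results. `tfp.mcmc.sample_chain` drives this loop for you. The pure-Python sketch below mimics the protocol with a hypothetical `ToyRWMKernel` whose results tuple caches the current target log-probability, so each step evaluates the target only at the proposal. The `ToyResults` fields and the use of an integer seed with `random.Random` are assumptions for illustration, not TFP's actual results structure.

```python
import collections
import math
import random

# Hypothetical results structure; TFP's kernel results have different fields.
ToyResults = collections.namedtuple('ToyResults', ['target_log_prob', 'is_accepted'])


class ToyRWMKernel:
    def __init__(self, target_log_prob_fn, scale=1.0):
        self.target_log_prob_fn = target_log_prob_fn
        self.scale = scale

    def bootstrap_results(self, init_state):
        # Cache log p(init_state) so one_step does not need to recompute it.
        return ToyResults(self.target_log_prob_fn(init_state), False)

    def one_step(self, current_state, previous_kernel_results, seed=None):
        rng = random.Random(seed)  # per-step seed, as in one_step's signature
        proposal = current_state + rng.gauss(0.0, self.scale)
        proposal_lp = self.target_log_prob_fn(proposal)
        log_ratio = proposal_lp - previous_kernel_results.target_log_prob
        if math.log(1.0 - rng.random()) < log_ratio:
            return proposal, ToyResults(proposal_lp, True)
        return current_state, ToyResults(previous_kernel_results.target_log_prob, False)


kernel = ToyRWMKernel(lambda x: -0.5 * x * x)
state = 1.0
results = kernel.bootstrap_results(state)
num_steps = 5000
num_accepted = 0
for step in range(num_steps):
    state, results = kernel.one_step(state, results, seed=step)
    num_accepted += results.is_accepted
acceptance_rate = num_accepted / num_steps
print('Acceptance rate:', acceptance_rate)
```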