(We omitted stop_gradient for brevity. See implementation for more details.)
The Avg{h[j;i] : j} term is a kind of "swap-out average" where the i-th
element has been replaced by the leave-i-out Geometric-average.
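For concreteness, here is a minimal sketch (plain jax.numpy; the helper name
log_swap_out_average is hypothetical and this is not the library
implementation) of computing log(Avg{h[j;i] : j}) for every i directly in
log-space:

    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    def log_swap_out_average(logu):
      """For each i, returns log(Avg{h[j;i] : j=0,...,m-1})."""
      m = logu.shape[0]
      # Leave-i-out geometric average, in log-space:
      # log GeomAvg{u[k] : k != i} = (sum_k logu[k] - logu[i]) / (m - 1).
      loo_geom = (jnp.sum(logu) - logu) / (m - 1.)
      # Build h[j;i] in log-space: keep logu[j] off the diagonal and swap in
      # the leave-i-out geometric average on the diagonal (j == i).
      swapped = jnp.where(jnp.eye(m, dtype=bool),
                          loo_geom[jnp.newaxis, :],
                          logu[:, jnp.newaxis])
      # Arithmetic average over j, still in log-space.
      return logsumexp(swapped, axis=0) - jnp.log(m)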
This implementation prefers numerical precision over efficiency, i.e.,
O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape)).
(The constant may be fairly large, perhaps around 12.)
Args
f
Python callable representing a Csiszar-function in log-space.
p_log_prob
Python callable representing the natural-log of the
probability under distribution p. (In variational inference p is the
joint distribution.)
q
tf.Distribution-like instance; must implement: sample(n, seed), and
log_prob(x). (In variational inference q is the approximate posterior
distribution.)
num_draws
Integer scalar number of draws used to approximate the
f-Divergence expectation.
num_batch_draws
Integer scalar number of draws used to approximate the
f-Divergence expectation.
seed
PRNG seed for q.sample; see tfp.random.sanitize_seed for details.
name
Python str name prefixed to Ops created by this function.
Returns
vimco
The Csiszar f-Divergence generalized VIMCO objective.
Raises
ValueError
if num_draws < 2.
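Example usage (a toy model; the distributions, parameter values, and choice
of f below are illustrative only, not prescribed by this API):

    import jax
    from tensorflow_probability.substrates import jax as tfp

    tfd = tfp.distributions

    # Toy joint p(x, h): standard-Normal prior on the latent h and a Normal
    # likelihood for a fixed "observation" x = 1.
    def p_log_prob(h):
      prior = tfd.Normal(loc=0., scale=1.)
      likelihood = tfd.Normal(loc=h, scale=0.5)
      return prior.log_prob(h) + likelihood.log_prob(1.)

    # Approximate posterior q(h | x).
    q = tfd.Normal(loc=0.5, scale=1.)

    vimco_loss = tfp.vi.csiszar_vimco(
        f=tfp.vi.kl_reverse,   # Csiszar-function in log-space
        p_log_prob=p_log_prob,
        q=q,
        num_draws=10,
        seed=jax.random.PRNGKey(0))

The resulting objective can then be minimized with respect to the parameters
of q (e.g., by differentiating through a parameterized q with jax.grad).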
References
[1]: Andriy Mnih and Danilo Rezende. Variational Inference for Monte Carlo
objectives. In International Conference on Machine Learning, 2016.
https://arxiv.org/abs/1602.06725
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2023-11-21 UTC."],[],[],null,["# tfp.substrates.jax.vi.csiszar_vimco\n\n\u003cbr /\u003e\n\n|------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| [View source on GitHub](https://github.com/tensorflow/probability/blob/v0.23.0/tensorflow_probability/substrates/jax/vi/csiszar_divergence.py#L1179-L1284) |\n\nUse VIMCO to lower the variance of gradient\\[csiszar_function(log(Avg(u))\\].\n\n#### View aliases\n\n\n**Main aliases**\n\n[`tfp.experimental.substrates.jax.vi.csiszar_vimco`](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/vi/csiszar_vimco)\n\n\u003cbr /\u003e\n\n tfp.substrates.jax.vi.csiszar_vimco(\n f, p_log_prob, q, num_draws, num_batch_draws=1, seed=None, name=None\n )\n\nThis function generalizes VIMCO \\[(Mnih and Rezende, 2016)\\]\\[1\\] to Csiszar\nf-Divergences.\n| **Note:** if `q.reparameterization_type = tfd.FULLY_REPARAMETERIZED`, consider using `monte_carlo_variational_loss`.\n\n#### The VIMCO loss is:\n\n vimco = f(log(Avg{u[i] : i=0,...,m-1}))\n where,\n logu[i] = log( p(x, h[i]) / q(h[i] | x) )\n h[i] iid~ q(H | x)\n\nInterestingly, the VIMCO gradient is not the naive gradient of `vimco`.\nRather, it is characterized by: \n\n grad[vimco] - variance_reducing_term\n where,\n variance_reducing_term = Sum{ grad[log q(h[i] | x)] *\n (vimco - f(log Avg{h[j;i] : j=0,...,m-1}))\n : i=0, ..., m-1 }\n h[j;i] = { u[j] j!=i\n { GeometricAverage{ u[k] : k!=i} j==i\n\n(We omitted `stop_gradient` for brevity. See implementation for more details.)\n\nThe `Avg{h[j;i] : j}` term is a kind of \"swap-out average\" where the `i`-th\nelement has been replaced by the leave-`i`-out Geometric-average.\n\nThis implementation prefers numerical precision over efficiency, i.e.,\n`O(num_draws * num_batch_draws * prod(batch_shape) * prod(event_shape))`.\n(The constant may be fairly large, perhaps around 12.)\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Args ---- ||\n|-------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|\n| `f` | Python `callable` representing a Csiszar-function in log-space. |\n| `p_log_prob` | Python `callable` representing the natural-log of the probability under distribution `p`. (In variational inference `p` is the joint distribution.) |\n| `q` | `tf.Distribution`-like instance; must implement: `sample(n, seed)`, and `log_prob(x)`. (In variational inference `q` is the approximate posterior distribution.) |\n| `num_draws` | Integer scalar number of draws used to approximate the f-Divergence expectation. |\n| `num_batch_draws` | Integer scalar number of draws used to approximate the f-Divergence expectation. |\n| `seed` | PRNG seed for `q.sample`; see [`tfp.random.sanitize_seed`](../../../../tfp/random/sanitize_seed) for details. |\n| `name` | Python `str` name prefixed to Ops created by this function. 
|\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Returns ------- ||\n|---------|-------------------------------------------------------|\n| `vimco` | The Csiszar f-Divergence generalized VIMCO objective. |\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Raises ------ ||\n|--------------|---------------------|\n| `ValueError` | if `num_draws \u003c 2`. |\n\n\u003cbr /\u003e\n\n#### References\n\n\\[1\\]: Andriy Mnih and Danilo Rezende. Variational Inference for Monte Carlo\nobjectives. In *International Conference on Machine Learning* , 2016.\n\u003chttps://arxiv.org/abs/1602.06725\u003e"]]