tfp.vi.mutual_information.lower_bound_jensen_shannon
Lower bound on Jensen-Shannon (JS) divergence.
tfp.vi.mutual_information.lower_bound_jensen_shannon(
logu, joint_sample_mask=None, validate_args=False, name=None
)
This lower bound on JS divergence is proposed in
[Goodfellow et al. (2014)][1] and [Nowozin et al. (2016)][2].
When estimating lower bounds on mutual information, one can also decouple how
the critic is trained from how mutual information is estimated
[(Poole et al., 2019)][3]: the critic is trained with the standard GAN-style
lower bound on the Jensen-Shannon divergence, and the trained critic is then
evaluated with the NWJ lower bound on the KL divergence, i.e. on the mutual
information.
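A minimal sketch of this train-with-JS, evaluate-with-NWJ recipe is below. It assumes a small separable Keras critic f(x, y) = g(x)^T h(y) and an Adam optimizer (both illustrative choices, not part of this API), and it assumes the companion NWJ bound `lower_bound_nguyen_wainwright` in `tfp.vi.mutual_information`:

import tensorflow as tf
import tensorflow_probability as tfp

# Illustrative separable critic: scores[i, j] = f(x[i], y[j]) = g(x[i])^T h(y[j]).
g = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),
                         tf.keras.layers.Dense(32)])
h = tf.keras.Sequential([tf.keras.layers.Dense(64, activation='relu'),
                         tf.keras.layers.Dense(32)])
optimizer = tf.keras.optimizers.Adam(1e-3)

def critic_scores(x, y):
  # Square score matrix over all pairs in the batch.
  return tf.matmul(g(x), h(y), transpose_b=True)

def train_step(x, y):
  # Train the critic by maximizing the JS lower bound (minimize its negation).
  with tf.GradientTape() as tape:
    loss = -tfp.vi.mutual_information.lower_bound_jensen_shannon(
        logu=critic_scores(x, y))
  variables = g.trainable_variables + h.trainable_variables
  optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
  return loss

def estimate_mi(x, y):
  # Evaluate the trained critic with the NWJ (KL) bound, which lower-bounds
  # the mutual information (assumed companion function, see lead-in above).
  return tfp.vi.mutual_information.lower_bound_nguyen_wainwright(
      logu=critic_scores(x, y))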
As in Eq. 7 and Eq. 8 of [Nowozin et al. (2016)][2], the bound is given by
I_JS = E_p(x,y)[log( D(x,y) )] + E_p(x)p(y)[log( 1 - D(x,y) )]
where the first term is an expectation over samples from the joint
distribution (positive samples) and the second is over samples from the
product of the marginal distributions (negative samples), with
D(x, y) = sigmoid(f(x, y)),
log(D(x, y)) = -softplus(-f(x, y)),
log(1 - D(x, y)) = -softplus(f(x, y)).
Here `f(x, y)` is a critic function that scores all pairs of samples.
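The softplus identities above can be checked directly, and they show how an estimate of I_JS could be assembled from a matrix of critic scores. The snippet below is a standalone sanity check in plain TensorFlow, not this module's implementation (which may differ, e.g. by constant offsets):

import tensorflow as tf

f = tf.random.normal([4, 4], seed=0)
d = tf.sigmoid(f)
# log D(x, y) = -softplus(-f(x, y)) and log(1 - D(x, y)) = -softplus(f(x, y)).
tf.debugging.assert_near(tf.math.log(d), -tf.math.softplus(-f), atol=1e-5)
tf.debugging.assert_near(tf.math.log(1. - d), -tf.math.softplus(f), atol=1e-5)

# With joint samples on the diagonal and independent pairs off the diagonal,
# the two expectations in I_JS are estimated by the corresponding means.
mask = tf.eye(4, dtype=bool)
positive = tf.reduce_mean(-tf.math.softplus(-tf.boolean_mask(f, mask)))
negative = tf.reduce_mean(-tf.math.softplus(tf.boolean_mask(f, ~mask)))
i_js_estimate = positive + negative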
Example:

`X`, `Y` are samples from a joint Gaussian distribution, with correlation 0.8 and both of dimension 1.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
lower_bound_jensen_shannon = tfp.vi.mutual_information.lower_bound_jensen_shannon

batch_size, rho, dim = 10000, 0.8, 1
y, eps = tf.split(
value=tf.random.normal(shape=(2 * batch_size, dim), seed=7),
num_or_size_splits=2, axis=0)
mean, conditional_stddev = rho * y, tf.sqrt(1. - tf.square(rho))
x = mean + conditional_stddev * eps
# Scores/unnormalized likelihood of pairs of samples `x[i], y[j]`
# (For the JS lower bound, the optimal critic is of the form `f(x, y) = 1 +
# log(p(x | y) / p(x))` [(Poole et al., 2019)][3].)
conditional_dist = tfd.MultivariateNormalDiag(
    mean, scale_diag=conditional_stddev * tf.ones((batch_size, dim)))
# `conditional_dist` is batched over `j` (mean `rho * y[j]`), so broadcasting
# `x[i]` against it yields a [batch_size, batch_size] matrix of
# log p(x[i] | y[j]).
conditional_scores = conditional_dist.log_prob(x[:, tf.newaxis, :])
marginal_dist = tfd.MultivariateNormalDiag(tf.zeros(dim), tf.ones(dim))
marginal_scores = marginal_dist.log_prob(x)[:, tf.newaxis]
scores = 1 + conditional_scores - marginal_scores
# Mask for joint samples in the score tensor
# (`scores` has shape [x_batch_size, y_batch_size] with
# `scores[i, j] = f(x[i], y[j])`; the joint samples `(x[i], y[i])` lie on the
# diagonal.)
joint_sample_mask = tf.eye(batch_size, dtype=bool)
# Lower bound on Jensen Shannon divergence
lower_bound_jensen_shannon(logu=scores, joint_sample_mask=joint_sample_mask)
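Because the scores above are built from the optimal critic f(x, y) = 1 + log(p(x | y) / p(x)), the same score matrix could also be fed to the NWJ bound to estimate the mutual information itself, whose closed form for this Gaussian pair is -0.5 * log(1 - rho**2), roughly 0.51 nats. A hedged follow-up, again assuming the companion `lower_bound_nguyen_wainwright` in this module:

# Estimate mutual information from the same critic scores with the NWJ bound
# (assumed companion function in `tfp.vi.mutual_information`).
mi_estimate = tfp.vi.mutual_information.lower_bound_nguyen_wainwright(
    logu=scores, joint_sample_mask=joint_sample_mask)
# Closed-form reference for the correlated Gaussian pair: about 0.5108 nats.
true_mi = -0.5 * tf.math.log(1. - tf.square(rho))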
Args:
  logu: `float`-like `Tensor` of size `[batch_size_1, batch_size_2]`
    representing critic scores for pairs of points `(x, y)` with
    `logu[i, j] = f(x[i], y[j])`.
  joint_sample_mask: `bool`-like `Tensor` of the same size as `logu`, masking
    the positive samples (i.e. samples from the joint distribution `p(x, y)`)
    with `True`. Default value: `None`, in which case an identity matrix is
    used as the mask.
  validate_args: Python `bool`, default `False`. Whether to validate input
    with asserts. If `validate_args` is `False` and the inputs are invalid,
    correct behavior is not guaranteed.
  name: Python `str` name prefixed to Ops created by this function.
    Default value: `None` (i.e., 'lower_bound_jensen_shannon').

Returns:
  lower_bound: `float`-like scalar for the lower bound on JS divergence.
References:
[1]: Ian J. Goodfellow, et al. Generative Adversarial Nets. In
Conference on Neural Information Processing Systems, 2014.
https://arxiv.org/abs/1406.2661
[2]: Sebastian Nowozin, Botond Cseke, Ryota Tomioka. f-GAN: Training
Generative Neural Samplers using Variational Divergence Minimization.
In Conference on Neural Information Processing Systems, 2016.
https://arxiv.org/abs/1606.00709
[3]: Ben Poole, Sherjil Ozair, Aaron van den Oord, Alexander A. Alemi,
George Tucker. On Variational Bounds of Mutual Information. In
International Conference on Machine Learning, 2019.
https://arxiv.org/abs/1905.06922