nsl.lib.jensen_shannon_divergence
Adds a Jensen-Shannon divergence to the training procedure.
nsl.lib.jensen_shannon_divergence(
    labels,
    predictions,
    axis=None,
    weights=1.0,
    scope=None,
    loss_collection=tf.compat.v1.GraphKeys.LOSSES,
    reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS
)
For brevity, let P = labels, Q = predictions, and KL(P||Q) be the Kullback-Leibler divergence as defined in the description of the nsl.lib.kl_divergence function. The Jensen-Shannon divergence (JSD) is
M = (P + Q) / 2
JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2
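The definition above can be checked numerically with a minimal pure-Python sketch. It handles a single pair of probability vectors (one position along axis), uses the natural logarithm, and treats 0 * log 0 as 0. We assume these conventions match nsl.lib.kl_divergence; the helpers kl_divergence and jensen_shannon_divergence below are hypothetical stand-ins, not the library functions.

```python
import math
from typing import Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(P||Q) = sum_i p_i * log(p_i / q_i), using the natural log.

    Terms with p_i == 0 contribute 0 (the usual 0 * log 0 = 0 convention).
    """
    total = 0.0
    for p_i, q_i in zip(p, q):
        if p_i > 0.0:
            total += p_i * math.log(p_i / q_i)
    return total


def jensen_shannon_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2 with M = (P + Q) / 2."""
    if len(p) != len(q):
        raise ValueError("p and q must have the same number of classes")
    # The mixture M; whenever p_i > 0 (or q_i > 0), m_i > 0 as well.
    m = [(p_i + q_i) / 2.0 for p_i, q_i in zip(p, q)]
    return kl_divergence(p, m) / 2.0 + kl_divergence(q, m) / 2.0
```

For example, jensen_shannon_divergence([1.0, 0.0], [0.0, 1.0]) gives ln 2 ≈ 0.693, the largest possible value with the natural log, while identical inputs give 0. Because p_i > 0 implies m_i >= p_i / 2 > 0, the divergence stays finite even when one distribution assigns zero probability to a class, unlike KL(P||Q) itself.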
This function assumes that predictions and labels are the values of a multinomial distribution, i.e., each value is the probability of the corresponding class, so the values along axis are non-negative and sum to 1.
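Concretely, a multinomial distribution along axis means every value lies in [0, 1] and the values sum to 1. The hypothetical helpers below sketch such a check, plus the usual fix when a model emits unnormalized logits rather than probabilities: apply a softmax first. The tolerance of 1e-6 is an assumption, not the tolerance the library uses.

```python
import math
from typing import List, Sequence


def is_probability_vector(values: Sequence[float], tol: float = 1e-6) -> bool:
    """True if every value is in [0, 1] and the values sum to 1 (within tol)."""
    return all(0.0 <= v <= 1.0 for v in values) and abs(sum(values) - 1.0) <= tol


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn unnormalized scores (logits) into a probability vector."""
    shift = max(logits)  # subtract the max so math.exp cannot overflow
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]
```

Raw logits such as [2.0, -1.0, 0.5] fail the check, while softmax([2.0, -1.0, 0.5]) passes it.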
For the usage of weights and reduction, please refer to tf.compat.v1.losses.
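As a rough illustration of the default reduction, the hypothetical helper below sketches SUM_BY_NONZERO_WEIGHTS on a flat list of per-position losses (one JSD value per position outside the class dimension, assumed to be already computed): the weighted sum divided by the number of non-zero weights. Returning 0.0 when every weight is zero is an assumption modeled on TF's safe division.

```python
from typing import Sequence, Union


def sum_by_nonzero_weights(losses: Sequence[float],
                           weights: Union[float, Sequence[float]] = 1.0) -> float:
    """Weighted sum of losses divided by the number of non-zero weights."""
    if isinstance(weights, (int, float)):
        weights = [float(weights)] * len(losses)  # broadcast a scalar weight
    if len(weights) != len(losses):
        raise ValueError("weights must be a scalar or match losses in length")
    weighted_sum = sum(loss * w for loss, w in zip(losses, weights))
    num_nonzero = sum(1 for w in weights if w != 0.0)
    # Avoid dividing by zero when every weight is zero.
    return weighted_sum / num_nonzero if num_nonzero else 0.0
```

With losses [0.2, 0.4, 0.6] and weights [1.0, 0.0, 2.0], the result is (0.2 * 1 + 0.6 * 2) / 2 = 0.7: the zero-weighted position drops out of both the numerator and the count. With the default scalar weight of 1.0, this reduces to a plain mean over all positions.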
Args
  labels: Tensor of type float32 or float64, with shape [d1, ..., dN, num_classes], representing the target distribution.
  predictions: Tensor of the same type and shape as labels, representing the predicted distribution.
  axis: The dimension along which the Jensen-Shannon divergence is computed. The values of labels and predictions along axis must meet the requirements of a multinomial distribution. Although axis defaults to None, it must be set explicitly; a None value raises a ValueError.
  weights: (Optional) Tensor whose rank is either 0 or the same as that of labels, and which must be broadcastable to labels (i.e., every dimension must be either 1 or the same as the corresponding dimension of the losses).
  scope: The scope for the operations performed in computing the loss.
  loss_collection: Collection to which the loss will be added.
  reduction: Type of reduction to apply to the loss.
Raises
  InvalidArgumentError: If labels or predictions don't meet the requirements of a multinomial distribution.
  ValueError: If axis is None, if the shape of predictions doesn't match that of labels, or if the shape of weights is invalid.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-01-12 UTC.