tf.nn.weighted_cross_entropy_with_logits
Computes a weighted cross entropy.
tf.nn.weighted_cross_entropy_with_logits(
labels, logits, pos_weight, name=None
)
This is like sigmoid_cross_entropy_with_logits() except that pos_weight
allows one to trade off recall and precision by up- or down-weighting the
cost of a positive error relative to a negative error.
The usual cross-entropy cost is defined as:
labels * -log(sigmoid(logits)) +
(1 - labels) * -log(1 - sigmoid(logits))
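As a quick check, this cost can be computed directly from the formula above. A minimal sketch (variable names are illustrative; the naive form can lose precision for large-magnitude logits, which is why the stable form shown later on this page is used internally):

import tensorflow as tf

labels = tf.constant([1., 0.5, 0.])
logits = tf.constant([1.5, -0.1, -10.])

# Naive transcription of the formula above.
sig = tf.sigmoid(logits)
manual = labels * -tf.math.log(sig) + (1. - labels) * -tf.math.log(1. - sig)

# Should agree with the built-in op up to float32 rounding.
builtin = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)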
A value pos_weight > 1 decreases the false negative count, hence increasing
the recall. Conversely, setting pos_weight < 1 decreases the false positive
count and increases the precision. This can be seen from the fact that
pos_weight is introduced as a multiplicative coefficient for the positive
labels term in the loss expression:
labels * -log(sigmoid(logits)) * pos_weight +
(1 - labels) * -log(1 - sigmoid(logits))
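One can verify numerically that scaling only the positive-label term reproduces the op's output. A sketch, reusing the example values that appear further down this page:

import tensorflow as tf

labels = tf.constant([1., 0.5, 0.])
logits = tf.constant([1.5, -0.1, -10.])
q = tf.constant(1.5)  # pos_weight

# Weighted formula above, transcribed directly.
sig = tf.sigmoid(logits)
manual = labels * -tf.math.log(sig) * q + (1. - labels) * -tf.math.log(1. - sig)

# manual and builtin agree up to float32 rounding.
builtin = tf.nn.weighted_cross_entropy_with_logits(
    labels=labels, logits=logits, pos_weight=q)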
For brevity, let x = logits, z = labels, q = pos_weight.
The loss is:
qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
= (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
= (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
Setting l = (1 + (q - 1) * z), to ensure stability and avoid overflow,
the implementation uses
(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))
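A minimal sketch of this stable form, checked against the op (values are illustrative):

import tensorflow as tf

x = tf.constant([1.5, -0.1, -10.])  # logits
z = tf.constant([1., 0.5, 0.])      # labels
q = 1.5                             # pos_weight

# The stable rewrite: -abs(x) keeps exp() from overflowing, and the
# max(-x, 0) term restores the contribution dropped by taking abs(x).
l = 1. + (q - 1.) * z
stable = (1. - z) * x + l * (tf.math.log1p(tf.exp(-tf.abs(x))) + tf.maximum(-x, 0.))

# Matches tf.nn.weighted_cross_entropy_with_logits(labels=z, logits=x, pos_weight=q).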
logits and labels must have the same type and shape.
labels = tf.constant([1., 0.5, 0.])
logits = tf.constant([1.5, -0.1, -10.])
tf.nn.weighted_cross_entropy_with_logits(
labels=labels, logits=logits, pos_weight=tf.constant(1.5)).numpy()
array([3.0211994e-01, 8.8049585e-01, 4.5776367e-05], dtype=float32)
tf.nn.weighted_cross_entropy_with_logits(
labels=labels, logits=logits, pos_weight=tf.constant(0.5)).numpy()
array([1.00706644e-01, 5.08297503e-01, 4.57763672e-05], dtype=float32)
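In practice, pos_weight is often set from the class imbalance of the training data, e.g. the ratio of negative to positive examples, so that both classes contribute comparably to the loss. A minimal sketch of wrapping the op as a mean-reduced loss function; the imbalance ratio of 10 and the helper name weighted_bce_from_logits are illustrative assumptions:

import tensorflow as tf

# Assumed imbalance: roughly 10 negatives per positive, so up-weight
# positives by that ratio. Replace with the ratio measured on your data.
pos_weight = 10.0

def weighted_bce_from_logits(y_true, y_pred_logits):
    # Per-example weighted cross entropy, reduced to a scalar loss.
    per_example = tf.nn.weighted_cross_entropy_with_logits(
        labels=y_true, logits=y_pred_logits, pos_weight=pos_weight)
    return tf.reduce_mean(per_example)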
Args

labels: A Tensor of the same type and shape as logits, with values
between 0 and 1 inclusive.

logits: A Tensor of type float32 or float64, any real numbers.

pos_weight: A coefficient to use on the positive examples, typically a
scalar but otherwise broadcastable to the shape of logits. Its value
should be non-negative.

name: A name for the operation (optional).

Returns

A Tensor of the same shape as logits with the componentwise
weighted logistic losses.

Raises

ValueError: If logits and labels do not have the same shape.