tfr.keras.losses.CoupledRankDistilLoss
Computes the Rank Distil loss between `y_true` and `y_pred`.
tfr.keras.losses.CoupledRankDistilLoss(
    reduction: tf.losses.Reduction = tf.losses.Reduction.AUTO,
    name: Optional[str] = None,
    ragged: bool = False,
    sample_size: int = 8,
    topk: Optional[int] = None,
    temperature: Optional[float] = 1.0
)
The Coupled-RankDistil loss (Reddi et al, 2021) is the cross-entropy between the k-Plackett's probability of the logits (student) and that of the labels (teacher).
Standalone usage:
tf.random.set_seed(1)
y_true = [[0., 2., 1.], [1., 0., 2.]]
ln = tf.math.log
y_pred = [[0., ln(3.), ln(2.)], [0., ln(2.), ln(3.)]]
loss = tfr.keras.losses.CoupledRankDistilLoss(topk=2, sample_size=1)
loss(y_true, y_pred).numpy()
2.138333
Usage with the `compile()` API:
model.compile(optimizer='sgd',
loss=tfr.keras.losses.CoupledRankDistilLoss())
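In a distillation setup, the "labels" passed to this loss are typically the teacher model's scores rather than graded relevance judgments. A minimal end-to-end sketch under that assumption (the student architecture and data shapes below are hypothetical, not from the library docs):

import tensorflow as tf
import tensorflow_ranking as tfr

# Hypothetical student scorer: maps [batch, list_size, feature_dim]
# item features to one score per item, i.e. [batch, list_size].
student = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1),
    tf.keras.layers.Flatten(),
])
student.compile(
    optimizer='sgd',
    loss=tfr.keras.losses.CoupledRankDistilLoss(topk=4, sample_size=8))

features = tf.random.normal([32, 10, 5])     # 32 queries, 10 items, 5 features
teacher_scores = tf.random.normal([32, 10])  # teacher logits play the role of y_true
student.fit(features, teacher_scores, epochs=1)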
Definition:
The k-Plackett's probability model is defined as:
\[
\mathcal{P}_k(\pi|s) = \frac{1}{(N-k)!}
\prod_{i=1}^k \frac{\exp(s_{\pi(i)})}{\sum_{j=i}^N \exp(s_{\pi(j)})}.
\]
The Coupled-RankDistil loss is defined as:
\[
\mathcal{L}(y, s) = -\sum_{\pi} \mathcal{P}_k(\pi|y) \log \mathcal{P}_k(\pi|s)
= \mathbb{E}_{\pi \sim \mathcal{P}_k(\cdot|y)} \left[-\log \mathcal{P}_k(\pi|s)\right]
\]
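In practice the expectation over permutations is approximated by Monte-Carlo sampling: draw `sample_size` permutations from the teacher's Plackett-Luce distribution over `y_true` (a Gumbel perturbation of the teacher scores does this) and average the student's negative k-Plackett log-likelihood over the samples. Below is a minimal sketch of that computation for dense inputs; the helper name is hypothetical, the constant 1/(N-k)! term is dropped (it does not depend on the student scores), and the penalty the library applies over non-top-k items is omitted for brevity. This is illustrative, not the library's implementation.

import tensorflow as tf

def coupled_rank_distil_sketch(y_true, y_pred, topk=None,
                               sample_size=8, temperature=1.0):
  """Illustrative Coupled-RankDistil loss for [batch, list_size] inputs."""
  y_true = tf.convert_to_tensor(y_true, tf.float32)
  s = tf.convert_to_tensor(y_pred, tf.float32) / temperature
  n = y_true.shape[1]
  k = n if topk is None else topk

  # Sorting Gumbel-perturbed teacher scores in descending order draws a
  # permutation pi ~ Plackett-Luce(y_true).
  batch = tf.shape(y_true)[0]
  gumbel = -tf.math.log(-tf.math.log(
      tf.random.uniform([batch, sample_size, n])))
  pi = tf.argsort(y_true[:, None, :] + gumbel,
                  axis=-1, direction='DESCENDING')

  # Student scores rearranged into each sampled order: s_{pi(1)}, ..., s_{pi(N)}.
  s_pi = tf.gather(tf.tile(s[:, None, :], [1, sample_size, 1]),
                   pi, batch_dims=2)

  # -log P_k(pi|s) = sum_{i=1}^k [logsumexp_{j>=i} s_{pi(j)} - s_{pi(i)}].
  suffix_lse = tf.math.cumulative_logsumexp(s_pi, axis=-1, reverse=True)
  nll = tf.reduce_sum((suffix_lse - s_pi)[..., :k], axis=-1)

  # Average over sampled permutations, then over the batch.
  return tf.reduce_mean(nll)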
References:
RankDistil: Knowledge Distillation for Ranking, Reddi et al, 2021 (https://research.google/pubs/pub50695/)

| Args | |
|---|---|
| `reduction` | (Optional) The `tf.keras.losses.Reduction` to use (see `tf.keras.losses.Loss`). |
| `name` | (Optional) The name for the op. |
| `ragged` | (Optional) If True, this loss will accept ragged tensors. If False, this loss will accept dense tensors. |
| `sample_size` | (Optional) Number of permutations to sample from teacher scores. Defaults to 8. |
| `topk` | (Optional) The number of top entries over which order is matched; a penalty is applied over non-top-k items. Defaults to `None`, which matches order over all entries in the list. |
| `temperature` | (Optional) A float used to rescale the logits as `logits = logits / temperature`. Defaults to 1. |
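With `ragged=True`, the loss accepts `tf.RaggedTensor` inputs, so queries may have lists of different lengths. An illustrative construction (the score values are arbitrary):

import tensorflow as tf
import tensorflow_ranking as tfr

loss = tfr.keras.losses.CoupledRankDistilLoss(
    ragged=True, sample_size=4, temperature=0.5)
y_true = tf.ragged.constant([[0., 2., 1.], [1., 0.]])       # teacher scores
y_pred = tf.ragged.constant([[0.5, 1.1, 0.2], [0.7, 0.3]])  # student scores
loss(y_true, y_pred).numpy()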
Methods
from_config
@classmethod
from_config(
    config
)
Instantiates a `Loss` from its config (output of `get_config()`).
| Args | |
|---|---|
| `config` | Output of `get_config()`. |

Returns
A `Loss` instance.
get_config
View source
get_config() -> Dict[str, Any]
Returns the config dictionary for a `Loss` instance.
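Together, `get_config()` and `from_config()` support the usual Keras round trip, e.g. for serializing a configured loss:

loss = tfr.keras.losses.CoupledRankDistilLoss(topk=2, sample_size=4)
config = loss.get_config()
restored = tfr.keras.losses.CoupledRankDistilLoss.from_config(config)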
__call__
View source
__call__(
    y_true: tfr.keras.model.TensorLike,
    y_pred: tfr.keras.model.TensorLike,
    sample_weight: Optional[utils.TensorLike] = None
) -> tf.Tensor
See `tf.keras.losses.Loss`.
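As with other Keras losses, `sample_weight` applies per-list weights before the reduction. For example (arbitrary scores):

loss = tfr.keras.losses.CoupledRankDistilLoss(sample_size=2)
y_true = tf.constant([[0., 2., 1.], [1., 0., 2.]])
y_pred = tf.constant([[0.5, 1.1, 0.2], [0.7, 0.3, 0.9]])
# Down-weight the second query's list.
loss(y_true, y_pred, sample_weight=tf.constant([1., 0.5])).numpy()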