Computes the Rank Distil loss between y_true and y_pred.
tfr.keras.losses.CoupledRankDistilLoss(
    reduction: tf.losses.Reduction = tf.losses.Reduction.AUTO,
    name: Optional[str] = None,
    ragged: bool = False,
    sample_size: int = 8,
    topk: Optional[int] = None,
    temperature: Optional[float] = 1.0
)
The Coupled-RankDistil loss (Reddi et al., 2021) is the cross-entropy between the k-Plackett's probabilities induced by the labels (teacher) and by the logits (student).
Standalone usage:
import tensorflow as tf
import tensorflow_ranking as tfr
tf.random.set_seed(1)
y_true = [[0., 2., 1.], [1., 0., 2.]]
ln = tf.math.log
y_pred = [[0., ln(3.), ln(2.)], [0., ln(2.), ln(3.)]]
loss = tfr.keras.losses.CoupledRankDistilLoss(topk=2, sample_size=1)
loss(y_true, y_pred).numpy()
2.138333
Usage with the compile() API:
model.compile(optimizer='sgd',
loss=tfr.keras.losses.CoupledRankDistilLoss())
Definition:
The k-Plackett's probability model is defined as:
\[ \mathcal{P}_k(\pi|s) = \frac{1}{(N-k)!} \prod_{i=1}^k \frac{\exp(s_{\pi(i)})}{\sum_{j=i}^N \exp(s_{\pi(j)})}. \]
The Coupled-RankDistil loss is defined as:
\[ \mathcal{L}(y, s) = -\sum_{\pi} \mathcal{P}_k(\pi|y) \log \mathcal{P}_k(\pi|s) = \mathbb{E}_{\pi \sim \mathcal{P}_k(\cdot|y)} [-\log \mathcal{P}_k(\pi|s)] \]
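For very short lists the sum over permutations can be evaluated exactly, which may help when checking the definition by hand. The sketch below is plain Python, not the library implementation. It assumes the k-Plackett's model is a Plackett-Luce distribution over the first k positions with the remaining N-k items ordered uniformly. The function names are hypothetical. Because the library estimates the expectation from sample_size sampled permutations, its output need not match this exact value.

```python
import itertools
import math


def k_plackett_log_prob(perm, scores, k):
    """Log-probability of `perm` under the assumed k-Plackett's model."""
    n = len(scores)
    # log(1 / (N - k)!): the last N - k positions are ordered uniformly.
    log_p = -math.lgamma(n - k + 1)
    for i in range(k):
        # Normalize over the items that are still unplaced at position i.
        log_norm = math.log(sum(math.exp(scores[j]) for j in perm[i:]))
        log_p += scores[perm[i]] - log_norm
    return log_p


def coupled_rank_distil_exact(y, s, k):
    """Exact cross-entropy between teacher (y) and student (s) distributions."""
    loss = 0.0
    for perm in itertools.permutations(range(len(y))):
        p_teacher = math.exp(k_plackett_log_prob(perm, y, k))
        loss -= p_teacher * k_plackett_log_prob(perm, s, k)
    return loss


# Single list taken from the standalone usage example, with topk=2.
y_true = [0.0, 2.0, 1.0]
y_pred = [0.0, math.log(3.0), math.log(2.0)]
print(coupled_rank_distil_exact(y_true, y_pred, k=2))
```

The enumeration visits N! permutations, so it is only practical for a handful of items. This is one reason to estimate the expectation from sampled permutations instead.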
References | |
---|---|
RankDistil: Knowledge Distillation for Ranking, Reddi et al., 2021 |
Args | |
---|---|
reduction | (Optional) The tf.keras.losses.Reduction to use (see tf.keras.losses.Loss). |
name | (Optional) The name for the op. |
ragged | (Optional) If True, this loss will accept ragged tensors. If False, this loss will accept dense tensors. |
sample_size | (Optional) Number of permutations to sample from teacher scores. Defaults to 8. |
topk | (Optional) Top-k entries over which order is matched. A penalty is applied over non top-k items. Defaults to None, which treats top-k as all entries in the list. |
temperature | (Optional) A float number to modify the logits as logits=logits/temperature. Defaults to 1. |
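To give a feel for what sample_size and temperature control, the sketch below draws permutations from a list of teacher scores and rescales logits by a temperature. It is an illustration in plain Python and makes no claim about how the library samples internally. It uses the Gumbel perturbation trick, in which sorting scores plus independent Gumbel noise in descending order yields a Plackett-Luce sample. The helper names sample_permutations and scale_logits are hypothetical.

```python
import math
import random


def scale_logits(logits, temperature):
    """Apply logits = logits / temperature, as described for `temperature`."""
    return [x / temperature for x in logits]


def sample_permutations(teacher_scores, sample_size, rng):
    """Draw `sample_size` Plackett-Luce permutations via Gumbel noise (assumed method)."""
    n = len(teacher_scores)
    perms = []
    for _ in range(sample_size):
        noisy = []
        for score in teacher_scores:
            # Clamp the uniform draw away from 0 and 1 so both logs are finite.
            u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
            noisy.append(score - math.log(-math.log(u)))
        # Descending order of perturbed scores gives one sampled ranking.
        perms.append(sorted(range(n), key=lambda j: noisy[j], reverse=True))
    return perms


rng = random.Random(1)
teacher = scale_logits([0.0, 2.0, 1.0], temperature=1.0)
for perm in sample_permutations(teacher, sample_size=8, rng=rng):
    print(perm)
```

Under this reading, a temperature above 1 flattens the score differences and so spreads the sampled permutations more evenly. A temperature below 1 concentrates them on the ordering implied by the scores.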
Methods
from_config
@classmethod
from_config(config)
Instantiates a Loss from its config (output of get_config()).
Args | |
---|---|
config | Output of get_config(). |
Returns | |
---|---|
A Loss instance. |
get_config
get_config() -> Dict[str, Any]
Returns the config dictionary for a Loss instance.
__call__
__call__(
    y_true: tfr.keras.model.TensorLike,
    y_pred: tfr.keras.model.TensorLike,
    sample_weight: Optional[utils.TensorLike] = None
) -> tf.Tensor
See tf.keras.losses.Loss.