Computes ListMLE loss between y_true and y_pred.
```python
tfr.keras.losses.ListMLELoss(
    reduction: tf.losses.Reduction = tf.losses.Reduction.AUTO,
    name: Optional[str] = None,
    lambda_weight: Optional[losses_impl._LambdaWeight] = None,
    temperature: float = 1.0,
    ragged: bool = False
)
```
Implements ListMLE loss (Xia et al., 2008). For each list of scores `s` in `y_pred` and list of labels `y` in `y_true`:

```
loss = - log P(permutation_y | s)
P(permutation_y | s) = Plackett-Luce probability of permutation_y given s
permutation_y = permutation of items sorted by labels y
```
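As a rough illustration of these three lines, the plain-Python sketch below (not the TF Ranking implementation) computes the loss for one dense list. It assumes no `lambda_weight` and `temperature=1.0`, and unlike `ListMLELoss` it breaks label ties deterministically by original position instead of randomly. The helper name `listmle_loss` is hypothetical.

```python
import math


def listmle_loss(labels, scores):
    """Negative log Plackett-Luce probability of the label-sorted permutation.

    Ties in `labels` keep their original order (Python's sort is stable);
    tfr.keras.losses.ListMLELoss breaks ties randomly instead.
    """
    # permutation_y: item indices sorted by label, highest label first.
    order = sorted(range(len(labels)), key=lambda i: -labels[i])
    ordered = [scores[i] for i in order]
    loss = 0.0
    for k in range(len(ordered)):
        # Log-probability of choosing item k among the items not yet placed
        # (positions k..n-1): a softmax over the remaining scores.
        tail = ordered[k:]
        m = max(tail)  # subtract the max for numerical stability
        log_norm = m + math.log(sum(math.exp(x - m) for x in tail))
        loss -= ordered[k] - log_norm
    return loss


print(round(listmle_loss([1.0, 0.0], [0.6, 0.8]), 5))  # prints 0.79814
```

With no tied labels the tie-breaking rule does not matter, so this agrees with the 0.7981389 in the standalone example.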
Standalone usage:
```python
>>> tf.random.set_seed(42)
>>> y_true = [[1., 0.]]
>>> y_pred = [[0.6, 0.8]]
>>> loss = tfr.keras.losses.ListMLELoss()
>>> loss(y_true, y_pred).numpy()
0.7981389
```
```python
>>> # Using ragged tensors
>>> tf.random.set_seed(42)
>>> y_true = tf.ragged.constant([[1., 0.], [0., 1., 0.]])
>>> y_pred = tf.ragged.constant([[0.6, 0.8], [0.5, 0.8, 0.4]])
>>> loss = tfr.keras.losses.ListMLELoss(ragged=True)
>>> loss(y_true, y_pred).numpy()
1.1613163
```
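The ragged result can be sanity-checked by hand. The sketch below is plain Python with a hypothetical helper `listmle_loss`; it assumes the `AUTO` reduction averages the per-list losses over the batch, and it ignores `lambda_weight` and `temperature`. In the second list, items 0 and 2 share label 0, so two permutations are consistent with the labels and their losses differ; that dependence on random tie-breaking is presumably why the example fixes the seed.

```python
import math


def listmle_loss(labels, scores, order=None):
    """ListMLE for one list; `order` optionally fixes how label ties are broken."""
    if order is None:
        # Stable sort: tied labels keep their original relative order.
        order = sorted(range(len(labels)), key=lambda i: -labels[i])
    ordered = [scores[i] for i in order]
    loss = 0.0
    for k in range(len(ordered)):
        tail = ordered[k:]
        m = max(tail)
        loss -= ordered[k] - (m + math.log(sum(math.exp(x - m) for x in tail)))
    return loss


# The two lists from the ragged example, of lengths 2 and 3.
first = listmle_loss([1.0, 0.0], [0.6, 0.8])
# Items 0 and 2 of the second list are tied, so both orders are valid.
second_a = listmle_loss([0.0, 1.0, 0.0], [0.5, 0.8, 0.4], order=[1, 0, 2])
second_b = listmle_loss([0.0, 1.0, 0.0], [0.5, 0.8, 0.4], order=[1, 2, 0])

# Batch mean over lists (assumed behavior of the AUTO reduction here).
print(round((first + second_a) / 2, 5))  # prints 1.16132
print(round((first + second_b) / 2, 5))  # prints 1.21132
```

Ordering the tied items as `[1, 0, 2]` gives a batch mean consistent with the 1.1613163 printed in the example (up to float32 rounding); the other valid ordering gives a different value.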
Usage with the compile() API:
```python
model.compile(optimizer='sgd', loss=tfr.keras.losses.ListMLELoss())
```
Definition:
\[ \mathcal{L}(\{y\}, \{s\}) = - \log(P(\pi_y | s)) \]
where \(P(\pi_y | s)\) is the Plackett-Luce probability of a permutation \(\pi_y\) conditioned on scores \(s\). Here \(\pi_y\) represents a permutation of items ordered by the relevance labels \(y\) where ties are broken randomly.
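Written out in the standard Plackett-Luce form, with \(\pi_y(i)\) denoting the item placed at position \(i\) in a list of \(n\) items, and assuming no lambda weight and a temperature of 1, the probability is a product of successive softmax choices over the items not yet placed:

\[ P(\pi_y | s) = \prod_{i=1}^{n} \frac{\exp\big(s_{\pi_y(i)}\big)}{\sum_{j=i}^{n} \exp\big(s_{\pi_y(j)}\big)} \quad\Longrightarrow\quad \mathcal{L}(\{y\}, \{s\}) = \sum_{i=1}^{n} \Big( \log \sum_{j=i}^{n} \exp\big(s_{\pi_y(j)}\big) - s_{\pi_y(i)} \Big) \]

For the standalone example, \(y = [1, 0]\) and \(s = [0.6, 0.8]\), so \(\pi_y\) places the item scored 0.6 first:

\[ \mathcal{L} = \Big( \log\big(e^{0.6} + e^{0.8}\big) - 0.6 \Big) + \big( \log e^{0.8} - 0.8 \big) = \log\big(1 + e^{0.2}\big) \approx 0.7981 \]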
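References

Xia, F., Liu, T.-Y., Wang, J., Zhang, W., and Li, H. Listwise approach to learning to rank: theory and algorithm. ICML 2008.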
| Args | |
|---|---|
| `reduction` | (Optional) The `tf.keras.losses.Reduction` to use (see `tf.keras.losses.Loss`). |
| `name` | (Optional) The name for the op. |
| `lambda_weight` | (Optional) A lambda weight to apply to the loss. Can be one of `tfr.keras.losses.DCGLambdaWeight`, `tfr.keras.losses.NDCGLambdaWeight`, `tfr.keras.losses.PrecisionLambdaWeight`, or `tfr.keras.losses.ListMLELambdaWeight`. |
| `temperature` | (Optional) The temperature to use for scaling the logits. |
| `ragged` | (Optional) If `True`, this loss accepts ragged tensors; if `False`, it accepts dense tensors. |
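This page describes `temperature` only as scaling the logits. Assuming the scaling divides `y_pred` by `temperature` (an assumption, not something stated here), a two-item list shows the effect compactly: with the relevant item ranked first, ListMLE reduces to \(\log(1 + \exp((s_2 - s_1)/T))\). The function `two_item_listmle` is a hypothetical helper for illustration, not part of TF Ranking.

```python
import math


def two_item_listmle(s_first, s_second, temperature=1.0):
    """ListMLE for a two-item list whose labels rank the `s_first` item first.

    Assumption: scores are divided by `temperature` before the loss is
    computed. Under that assumption the loss reduces to
    log(1 + exp((s_second - s_first) / temperature)).
    """
    return math.log1p(math.exp((s_second - s_first) / temperature))


# Scores from the standalone example: the relevant item (0.6) is scored
# below the other item (0.8), so the pair is mis-ordered.
for t in (0.5, 1.0, 2.0, 100.0):
    print(t, round(two_item_listmle(0.6, 0.8, t), 4))
```

Under this assumption, a larger temperature flattens the scores and pulls the loss toward \(\log 2\), while a smaller one sharpens them, which for this mis-ordered pair increases the loss.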
Methods
from_config
```python
@classmethod
from_config(
    config, custom_objects=None
)
```
Instantiates a Loss from its config (output of get_config()).
| Args | |
|---|---|
| `config` | Output of `get_config()`. |

| Returns | |
|---|---|
| A `Loss` instance. | |
get_config
```python
get_config() -> Dict[str, Any]
```
Returns the config dictionary for a Loss instance.
__call__
```python
__call__(
    y_true: tfr.keras.model.TensorLike,
    y_pred: tfr.keras.model.TensorLike,
    sample_weight: Optional[utils.TensorLike] = None
) -> tf.Tensor
```
See tf.keras.losses.Loss.