Computes the Gumbel approximate NDCG loss between y_true and y_pred.
Inherits From: ApproxNDCGLoss
tfr.keras.losses.GumbelApproxNDCGLoss(
reduction: tf.losses.Reduction = tf.losses.Reduction.AUTO,
name: Optional[str] = None,
lambda_weight: Optional[losses_impl._LambdaWeight] = None,
temperature: float = 0.1,
sample_size: int = 8,
gumbel_temperature: float = 1.0,
seed: Optional[int] = None,
ragged: bool = False
)
Implementation of Gumbel ApproxNDCG loss (Bruch et al., 2020).
This loss is the same as tfr.keras.losses.ApproxNDCGLoss, except that the logits are sampled from the Gumbel distribution:
y_new_pred ~ Gumbel(y_pred, 1 / temperature)
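To make the description above concrete, here is a minimal pure-Python sketch. It assumes the parameterization from the Definition section below (Gumbel noise with scale 1/β added to each score) and an ApproxNDCG with 2^y - 1 gains, logarithmic discounts, and sigmoid-smoothed ranks. The function names (`sample_gumbel_scores`, `approx_ndcg`, `gumbel_approx_ndcg_loss`) and the `gumbel_scale` parameter are hypothetical. This is not the TF-Ranking implementation: it ignores weights, masking, `lambda_weight`, and reduction, and it is not expected to reproduce the numbers in the examples below.

```python
import math
import random
from typing import List, Optional, Sequence


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_gumbel_scores(scores: Sequence[float], gumbel_scale: float,
                         rng: random.Random) -> List[float]:
    # z = s + scale * g with g = -log(-log(u)), u ~ Uniform(0, 1), so
    # (z - s) / scale is a standard Gumbel draw (scale plays the role of 1/beta).
    sampled = []
    for s in scores:
        u = max(rng.random(), 1e-12)  # rng.random() may return 0.0; avoid log(0)
        sampled.append(s - gumbel_scale * math.log(-math.log(u)))
    return sampled


def approx_ndcg(labels: Sequence[float], scores: Sequence[float],
                temperature: float) -> float:
    n = len(scores)
    dcg = 0.0
    for i in range(n):
        # Soft 1-based rank: 1 + smoothed count of items scored above item i.
        rank = 1.0 + sum(_sigmoid((scores[j] - scores[i]) / temperature)
                         for j in range(n) if j != i)
        dcg += (2.0 ** labels[i] - 1.0) / math.log2(1.0 + rank)
    # Ideal DCG: exact ranks 1..n for labels sorted in descending order.
    ideal = sum((2.0 ** y - 1.0) / math.log2(k + 2.0)
                for k, y in enumerate(sorted(labels, reverse=True)))
    return dcg / ideal if ideal > 0 else 0.0


def gumbel_approx_ndcg_loss(labels: Sequence[float], scores: Sequence[float],
                            temperature: float = 0.1, sample_size: int = 8,
                            gumbel_scale: float = 1.0,
                            seed: Optional[int] = None) -> float:
    # Negative ApproxNDCG averaged over `sample_size` Gumbel-perturbed score lists.
    rng = random.Random(seed)
    total = 0.0
    for _ in range(sample_size):
        perturbed = sample_gumbel_scores(scores, gumbel_scale, rng)
        total += approx_ndcg(labels, perturbed, temperature)
    return -total / sample_size


print(gumbel_approx_ndcg_loss([1.0, 0.0], [0.6, 0.8], seed=42))
```

Averaging over several draws is one simple way to turn the random perturbation into a single loss value; the library's own reduction over the sampled lists may differ.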
Standalone usage:
import tensorflow as tf
import tensorflow_ranking as tfr
tf.random.set_seed(42)
y_true = [[1., 0.]]
y_pred = [[0.6, 0.8]]
loss = tfr.keras.losses.GumbelApproxNDCGLoss(seed=42)
loss(y_true, y_pred).numpy()
-0.8160851
# Using a higher gumbel temperature
loss = tfr.keras.losses.GumbelApproxNDCGLoss(gumbel_temperature=2.0,
                                             seed=42)
loss(y_true, y_pred).numpy()
-0.7583889
# Using ragged tensors
y_true = tf.ragged.constant([[1., 0.], [0., 1., 0.]])
y_pred = tf.ragged.constant([[0.6, 0.8], [0.5, 0.8, 0.4]])
loss = tfr.keras.losses.GumbelApproxNDCGLoss(seed=42, ragged=True)
loss(y_true, y_pred).numpy()
-0.6987189
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tfr.keras.losses.GumbelApproxNDCGLoss())
Definition:
\[\mathcal{L}(\{y\}, \{s\}) = \text{ApproxNDCGLoss}(\{y\}, \{z\})\]
where
\[ \begin{aligned} z &\sim \text{Gumbel}(s, \beta) \\ p(z) &= \beta\, e^{-t-e^{-t}} \\ t &= \beta(z - s) \\ \beta &= \frac{1}{\text{temperature}} \end{aligned} \]
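One standard way to draw such samples, given here as a worked derivation rather than as the library's documented procedure, is inverse-transform sampling: with t and β as defined above, the Gumbel CDF is inverted at a uniform draw u.

\[ \begin{aligned} F(z) &= \Pr[Z \le z] = \exp\!\left(-e^{-t}\right), \quad t = \beta(z - s) \\ u &= F(z), \quad u \sim \text{Uniform}(0, 1) \\ \Rightarrow\; z &= s - \frac{1}{\beta}\log(-\log u) = s + \text{temperature} \cdot g, \quad g = -\log(-\log u) \end{aligned} \]

Because the noise g does not depend on s, the perturbed score satisfies ∂z/∂s = 1, so gradients pass straight through the sampling step to the original scores.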
References | |
---|---|
Args | |
---|---|
reduction | Type of tf.keras.losses.Reduction to apply to the loss. Default value is AUTO. AUTO indicates that the reduction option will be determined by the usage context. For almost all cases this defaults to SUM_OVER_BATCH_SIZE. When used under a tf.distribute.Strategy, except via Model.compile() and Model.fit(), using AUTO or SUM_OVER_BATCH_SIZE will raise an error. Please see the custom training tutorial for more details.
name | Optional name for the instance.
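As a rough illustration of the arithmetic behind the reduction options, the sketch below reduces a flat list of per-list loss values in plain Python. `reduce_weighted_losses` is a hypothetical helper, not the Keras API; real Keras reductions operate on tensors and also deal with distribution strategies. It assumes that AUTO resolves to SUM_OVER_BATCH_SIZE, as described above for almost all cases.

```python
from typing import List, Optional, Sequence, Union


def reduce_weighted_losses(
    losses: Sequence[float],
    reduction: str = "sum_over_batch_size",
    sample_weight: Optional[Sequence[float]] = None,
) -> Union[float, List[float]]:
    """Plain-Python analogue of the NONE, SUM and SUM_OVER_BATCH_SIZE options."""
    if sample_weight is None:
        sample_weight = [1.0] * len(losses)
    if len(sample_weight) != len(losses):
        raise ValueError("sample_weight must have one entry per loss value")
    weighted = [loss * w for loss, w in zip(losses, sample_weight)]
    if reduction == "none":
        return weighted  # no reduction: one value per element
    total = sum(weighted)
    if reduction == "sum":
        return total
    if reduction == "sum_over_batch_size":
        # Divide by the number of loss elements, not by the sum of the weights.
        return total / len(weighted) if weighted else 0.0
    raise ValueError(f"unknown reduction: {reduction!r}")


per_list = [-0.5, -0.25, -0.75]  # made-up per-list loss values
print(reduce_weighted_losses(per_list, "sum"))                  # -1.5
print(reduce_weighted_losses(per_list, "sum_over_batch_size"))  # -0.5
```

Note that a zero weight still counts toward the SUM_OVER_BATCH_SIZE denominator, so down-weighting lists shrinks the mean rather than excluding those lists.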
Methods
from_config
@classmethod
from_config( config, custom_objects=None )
Instantiates a Loss from its config (output of get_config()).
Args | |
---|---|
config | Output of get_config().

Returns | |
---|---|
A Loss instance. |
get_config
get_config() -> Dict[str, Any]
Returns the config dictionary for a Loss instance.
__call__
__call__(
y_true: tfr.keras.model.TensorLike,
y_pred: tfr.keras.model.TensorLike,
sample_weight: Optional[utils.TensorLike] = None
) -> tf.Tensor
See _RankingLoss.