Zeroes the logits of accidental negatives.
```python
tfrs.layers.loss.RemoveAccidentalHits(
    trainable=True, name=None, dtype=None, dynamic=False, **kwargs
)
```
Methods
call
```python
call(
    labels: tf.Tensor, logits: tf.Tensor, candidate_ids: tf.Tensor
) -> tf.Tensor
```
Zeroes selected logits.
For each row in the batch, zeroes the logits of negative candidates that have the same id as that row's positive candidate (accidental hits), so that a duplicate of the positive item is not treated as a negative.
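The row-wise rule above can be sketched in NumPy. This is a minimal illustration, not the TFRS implementation: the function `remove_accidental_hits` and the constant `MASK_VALUE` are hypothetical. The sketch assumes the goal is to drop accidental hits from a downstream softmax, so it pushes their logits to a very large negative value instead of literally writing 0.0 (a logit of exactly 0 would still contribute exp(0) = 1 to the softmax denominator).

```python
import numpy as np

# Hypothetical masking value (assumption): a very large negative float32, so
# that exp(logit) underflows to ~0 inside a softmax.
MASK_VALUE = np.float32(np.finfo(np.float32).min / 100.0)


def remove_accidental_hits(labels, logits, candidate_ids):
    """Mask logits of negatives that share an id with the row's positive.

    labels:        [batch_size, num_candidates] one-hot array.
    logits:        [batch_size, num_candidates] float array.
    candidate_ids: [num_candidates] array of candidate identifiers.
    """
    labels = np.asarray(labels, dtype=np.float32)
    logits = np.asarray(logits, dtype=np.float32)
    candidate_ids = np.asarray(candidate_ids)

    # Column index of the positive candidate in each row, then its id.
    positive_indices = np.argmax(labels, axis=1)   # [batch_size]
    positive_ids = candidate_ids[positive_indices]  # [batch_size]

    # True wherever a column's id equals the row's positive id
    # (this includes the positive column itself).
    same_id = positive_ids[:, None] == candidate_ids[None, :]

    # Keep only the negatives: these are the accidental hits.
    accidental_hits = same_id & (labels == 0)

    return np.where(accidental_hits, MASK_VALUE, logits)


labels = np.array([[1, 0, 0],
                   [0, 1, 0]])
logits = np.array([[2.0, 1.0, 0.5],
                   [0.3, 1.5, 0.7]])
candidate_ids = np.array(["a", "b", "a"])  # column 2 repeats item "a"

out = remove_accidental_hits(labels, logits, candidate_ids)
# Row 0: the positive is "a" (column 0), so column 2 is an accidental hit and is masked.
# Row 1: the positive is "b" and no other column has id "b", so the row is unchanged.
```

Comparing against the positive's id (rather than the label matrix alone) is what catches duplicates: the same item can appear in several columns of a sampled batch, and only the column marked by the one-hot label is the true positive.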
Args | Description |
---|---|
`labels` | `[batch_size, num_candidates]` one-hot labels tensor. |
`logits` | `[batch_size, num_candidates]` logits tensor. |
`candidate_ids` | `[num_candidates]` candidate identifiers tensor. |
Returns | Description |
---|---|
`logits` | Modified logits, with the same `[batch_size, num_candidates]` shape as the input `logits`. |
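To see why the returned logits matter, the sketch below compares a row-wise softmax cross-entropy on raw logits and on logits where an accidental hit has been masked. It is an illustration under stated assumptions: `softmax_cross_entropy` is a hypothetical helper written here, not a TFRS function, and the masking value is an assumed large negative number.

```python
import numpy as np


def softmax_cross_entropy(labels, logits):
    """Row-wise softmax cross-entropy with one-hot labels (numerically stable)."""
    labels = np.asarray(labels, dtype=np.float64)
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row max before exponentiating to avoid overflow.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1)


labels = np.array([[1.0, 0.0, 0.0]])
# Hypothetical batch: column 2 is the same item as the positive in column 0.
raw_logits = np.array([[2.0, 1.0, 2.0]])

# Modified logits: the accidental hit pushed to a large negative value (assumption).
masked_logits = raw_logits.copy()
masked_logits[0, 2] = np.finfo(np.float32).min / 100.0

raw_loss = softmax_cross_entropy(labels, raw_logits)        # log(2 + e^-1)
masked_loss = softmax_cross_entropy(labels, masked_logits)  # log(1 + e^-1)
```

Masking the duplicate lowers this row's loss from about 0.862 to about 0.313. Without it, the gradient would push the model to lower the score of a candidate that is actually the positive item.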