Transforms logits and labels to return hard negatives.
tfrs.layers.loss.HardNegativeMining(
num_hard_negatives: int
) -> None
Methods
call
call(
logits: tf.Tensor, labels: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]
Filters logits and labels with per-query hard negative mining.
For each query (row), the result keeps the logits and labels of the positive candidate and of the num_hard_negatives negative candidates with the highest logits.
Args | |
---|---|
logits | [batch_size, number_of_candidates] tensor of logits.
labels | [batch_size, number_of_candidates] one-hot tensor of labels.
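The labels argument must be one-hot along the candidate axis. The following minimal NumPy sketch builds such a matrix. It assumes each query's positive candidate is known by its column index. The helper one_hot_labels and its positive_columns input are hypothetical names used only for illustration. For example, if query i's positive is always candidate i, the result is the identity matrix.

```python
import numpy as np

def one_hot_labels(positive_columns, number_of_candidates):
    """Return a [batch_size, number_of_candidates] one-hot label matrix.

    Hypothetical helper: positive_columns[i] is the column index of the
    positive candidate for query i.
    """
    # Row j of the identity matrix is the one-hot vector for column j,
    # so indexing it by the positive columns yields one row per query.
    return np.eye(number_of_candidates, dtype=np.float32)[np.asarray(positive_columns)]

labels = one_hot_labels([0, 2], number_of_candidates=5)
# labels.tolist() -> [[1.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0]]
```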
Returns | |
---|---|
logits | [batch_size, num_hard_negatives + 1] tensor of logits.
labels | [batch_size, num_hard_negatives + 1] one-hot tensor of labels.
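The per-query filtering performed by call can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the TFRS implementation. The function name hard_negative_mining is hypothetical, and the sketch assumes exactly one positive per row. It also caps the number of kept columns at number_of_candidates so that it never requests more columns than exist. Raising the positive's score to +inf before the top-k selection guarantees that the positive survives even when some negatives score higher.

```python
import numpy as np

def hard_negative_mining(logits, labels, num_hard_negatives):
    """Keep each row's positive plus its highest-scoring negatives (sketch).

    Assumes exactly one positive (label > 0) per row of `labels`.
    """
    num_candidates = logits.shape[1]
    # Sketch assumption: never select more columns than exist.
    num_sampled = min(num_hard_negatives + 1, num_candidates)
    # Replace the positive's score with +inf so it ranks first in its row.
    boosted = np.where(labels > 0, np.inf, logits)
    # Column indices of the num_sampled largest boosted scores per row,
    # in descending order (stable sort breaks ties by column index).
    col_indices = np.argsort(-boosted, axis=1, kind="stable")[:, :num_sampled]
    # Gather the original (un-boosted) logits and labels at those columns.
    return (np.take_along_axis(logits, col_indices, axis=1),
            np.take_along_axis(labels, col_indices, axis=1))

logits = np.array([[0.1, 2.0, -1.0, 3.0, 0.5],
                   [4.0, 0.2, 1.5, -0.3, 2.5]])
labels = np.array([[1.0, 0.0, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0, 0.0, 0.0]])
hard_logits, hard_labels = hard_negative_mining(logits, labels, num_hard_negatives=2)
# hard_logits.tolist() -> [[0.1, 3.0, 2.0], [1.5, 4.0, 2.5]]
# hard_labels.tolist() -> [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
```

In the second row, the negative at 4.0 outscores the positive (logit 1.5), but the positive is still kept. Both outputs have shape [batch_size, num_hard_negatives + 1], as listed above. The column order of the kept entries (positive first, then negatives by descending logit) is a property of this sketch, not a documented guarantee of the layer.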