Transforms logits and labels to return hard negatives.
tfrs.layers.loss.HardNegativeMining(
num_hard_negatives: int
) -> None
Args:
  num_hard_negatives: Number of hard negatives to keep for each query.
Methods
call
call(
logits: tf.Tensor, labels: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]
Filters logits and labels with per-query hard negative mining.
For each query (row), the result contains the logits and labels of the
positive candidate plus num_hard_negatives negative candidates.
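A minimal pure-Python sketch of the idea follows. It assumes that "hard" negatives are the negatives with the highest logits, which is the usual meaning of the term; the exact selection rule of the library layer is not spelled out on this page. The function name `hard_negative_mining` is hypothetical, and nested lists stand in for tensors. For readability the sketch places the positive in the first column, so do not rely on any particular column order from the real layer.

```python
from typing import List, Tuple


def hard_negative_mining(
    logits: List[List[float]],
    labels: List[List[float]],
    num_hard_negatives: int,
) -> Tuple[List[List[float]], List[List[float]]]:
    """Keep, per row, the positive plus the highest-scoring negatives."""
    mined_logits, mined_labels = [], []
    for row_logits, row_labels in zip(logits, labels):
        # Index of the single positive candidate (labels are one-hot).
        positive = row_labels.index(1.0)
        # Rank the remaining candidates by logit, highest (hardest) first.
        # If a row has fewer negatives than requested, all of them are kept.
        negatives = sorted(
            (j for j in range(len(row_logits)) if j != positive),
            key=lambda j: row_logits[j],
            reverse=True,
        )[:num_hard_negatives]
        keep = [positive] + negatives
        mined_logits.append([row_logits[j] for j in keep])
        mined_labels.append([row_labels[j] for j in keep])
    return mined_logits, mined_labels


# Two queries, four candidates each; keep 2 hard negatives per query.
logits = [[0.9, 0.1, 0.5, 0.7],
          [0.2, 0.8, 0.3, 0.6]]
labels = [[1.0, 0.0, 0.0, 0.0],
          [0.0, 1.0, 0.0, 0.0]]
mined_logits, mined_labels = hard_negative_mining(logits, labels, num_hard_negatives=2)
# mined_logits → [[0.9, 0.7, 0.5], [0.8, 0.6, 0.3]]
# mined_labels → [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
```

Each output row has num_hard_negatives + 1 entries, and the lowest-scoring negatives (0.1 and 0.2 here) are dropped.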
Args:
  logits: [batch_size, number_of_candidates] tensor of logits.
  labels: [batch_size, number_of_candidates] one-hot tensor of labels.
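Inputs of this shape are often produced with in-batch negatives. Every query is scored against every candidate in the batch, and candidate i is treated as the positive for query i, so the labels form an identity matrix. The sketch below assumes that setup, although the layer itself only requires one-hot labels. The helper name `in_batch_logits_and_labels` is hypothetical, and nested lists stand in for tensors.

```python
from typing import List, Tuple


def in_batch_logits_and_labels(
    query_embeddings: List[List[float]],
    candidate_embeddings: List[List[float]],
) -> Tuple[List[List[float]], List[List[float]]]:
    """Score every query against every candidate in the batch.

    Assumes candidate i is the positive for query i (in-batch negatives),
    so labels are an identity matrix of shape [batch_size, batch_size].
    """
    # Dot-product score for each (query, candidate) pair.
    logits = [
        [sum(q * c for q, c in zip(query, cand)) for cand in candidate_embeddings]
        for query in query_embeddings
    ]
    n = len(candidate_embeddings)
    # One-hot rows: a 1.0 on the diagonal marks each query's positive.
    labels = [
        [1.0 if i == j else 0.0 for j in range(n)]
        for i in range(len(query_embeddings))
    ]
    return logits, labels


queries = [[1.0, 0.0], [0.0, 1.0]]
candidates = [[0.5, 0.5], [0.2, 0.9]]
logits, labels = in_batch_logits_and_labels(queries, candidates)
# logits → [[0.5, 0.2], [0.5, 0.9]]
# labels → [[1.0, 0.0], [0.0, 1.0]]
```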
Returns:
  A (logits, labels) tuple:
  logits: [batch_size, num_hard_negatives + 1] tensor of logits.
  labels: [batch_size, num_hard_negatives + 1] one-hot tensor of labels.
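Each returned row still contains exactly one positive. The mined pair can therefore be fed to a softmax cross-entropy loss computed over only the num_hard_negatives + 1 kept candidates instead of all candidates. The pure-Python sketch below assumes softmax cross-entropy as the downstream loss, which is a common choice for retrieval rather than something this page specifies. The name `softmax_cross_entropy` is a hypothetical helper, not a library call.

```python
import math
from typing import List


def softmax_cross_entropy(logits: List[List[float]], labels: List[List[float]]) -> float:
    """Mean softmax cross-entropy over rows; labels are one-hot."""
    total = 0.0
    for row_logits, row_labels in zip(logits, labels):
        # Log-sum-exp with max subtraction for numerical stability.
        m = max(row_logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in row_logits))
        # With one-hot labels, only the positive's log-probability contributes.
        total -= sum(y * (x - log_z) for x, y in zip(row_logits, row_labels))
    return total / len(logits)


# Mined outputs of shape [batch_size, num_hard_negatives + 1], num_hard_negatives = 2.
mined_logits = [[0.9, 0.7, 0.5], [0.8, 0.6, 0.3]]
mined_labels = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
loss = softmax_cross_entropy(mined_logits, mined_labels)
```

Restricting the softmax to the hardest negatives concentrates the loss on the candidates the model currently confuses with the positive.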