Implements the focal loss function.
@tf.function
tfa.losses.sigmoid_focal_crossentropy(
    y_true: tfa.types.TensorLike,
    y_pred: tfa.types.TensorLike,
    alpha: tfa.types.FloatTensorLike = 0.25,
    gamma: tfa.types.FloatTensorLike = 2.0,
    from_logits: bool = False
) -> tf.Tensor
Focal loss was first introduced in the RetinaNet paper
(https://arxiv.org/pdf/1708.02002.pdf). It is useful for classification
with highly imbalanced classes: it down-weights well-classified (easy)
examples so that training focuses on hard ones. As a result, a sample that
the classifier misclassifies incurs a much larger loss than a well-classified
sample does. A common use case is object detection, where the background
class vastly outnumbers all other classes.
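For reference, the per-entry focal loss from the RetinaNet paper can be written as below, where y ∈ {0, 1} is the target and p is the predicted probability of the positive class (the sigmoid of the logit when from_logits is True). The symbols p_t and α_t follow the paper's notation; reading α and γ as the alpha and gamma arguments of this function is an interpretation of the signature, not something this page states.

```latex
p_t = \begin{cases} p, & y = 1 \\ 1 - p, & y = 0 \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha, & y = 1 \\ 1 - \alpha, & y = 0 \end{cases}

\mathrm{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)
```

With γ = 2, a positive example predicted at p = 0.9 has its cross-entropy scaled by (1 - 0.9)^2 = 0.01, while one predicted at p = 0.1 is scaled by (1 - 0.1)^2 = 0.81. Setting γ = 0 recovers α-weighted binary cross-entropy.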
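Args:
  y_true: true targets tensor (binary targets, one per entry).
  y_pred: predictions tensor; probabilities, or logits when from_logits
    is True.
  alpha: balancing factor; positive targets are weighted by alpha and
    negative targets by 1 - alpha. Defaults to 0.25.
  gamma: modulating factor; each entry's loss is scaled by (1 - p_t)^gamma,
    where p_t is the predicted probability of the true class. Larger values
    down-weight well-classified examples more strongly; gamma = 0 disables
    the modulation. Defaults to 2.0.
  from_logits: whether y_pred contains logits rather than probabilities.
    Defaults to False.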
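The way alpha, gamma and from_logits enter the computation can be illustrated with a minimal NumPy sketch. This is an assumption-laden re-implementation, not the TFA source: the name focal_loss_sketch and the eps clipping are hypothetical, and the sketch returns per-entry losses with no reduction, so its output shape is not meant to match what tfa.losses.sigmoid_focal_crossentropy returns.

```python
import numpy as np


def focal_loss_sketch(y_true, y_pred, alpha=0.25, gamma=2.0, from_logits=False):
    """Hypothetical per-entry sigmoid focal loss (illustration only, not TFA)."""
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)

    # Convert logits to probabilities with the sigmoid when requested.
    prob = 1.0 / (1.0 + np.exp(-y_pred)) if from_logits else y_pred

    # Assumed safeguard: clip probabilities so log() never sees 0 or 1.
    eps = 1e-7
    prob = np.clip(prob, eps, 1.0 - eps)

    # Plain binary cross-entropy for each entry.
    ce = -(y_true * np.log(prob) + (1.0 - y_true) * np.log(1.0 - prob))

    # p_t: probability the model assigns to the true class.
    p_t = y_true * prob + (1.0 - y_true) * (1.0 - prob)

    # alpha weights positives, (1 - alpha) weights negatives.
    alpha_t = y_true * alpha + (1.0 - y_true) * (1.0 - alpha)

    # (1 - p_t)^gamma shrinks the loss of well-classified entries.
    modulating = (1.0 - p_t) ** gamma

    return alpha_t * modulating * ce


easy = focal_loss_sketch([1.0], [0.9])  # confident and correct
hard = focal_loss_sketch([1.0], [0.1])  # confident and wrong
print(easy, hard)
```

With the defaults, the easy example's loss is roughly 0.00026 and the hard one's roughly 0.47, a ratio of about 1,770, whereas plain cross-entropy separates the same two examples by a factor of only about 22.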
Returns:
  Weighted loss float Tensor. If reduction is NONE, this has the same
  shape as y_true; otherwise, it is scalar.