|
|
|
@ -436,7 +436,7 @@ class DiceLoss(_Loss):
|
|
|
|
|
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
|
|
|
|
>>> output = loss(y_pred, y)
|
|
|
|
|
>>> print(output)
|
|
|
|
|
0.38596618
|
|
|
|
|
[0.38596618]
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, smooth=1e-5):
|
|
|
|
|
super(DiceLoss, self).__init__()
|
|
|
|
@ -1027,6 +1027,12 @@ def _check_channel_and_shape(predict, target):
|
|
|
|
|
f"inferred from 'predict': C={predict}.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _check_input_dtype(targets_dtype, cls_name):
|
|
|
|
|
validator.check_type_name("targets", targets_dtype, [mstype.int32, mstype.int64, mstype.float16,
|
|
|
|
|
mstype.float32], cls_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FocalLoss(_Loss):
|
|
|
|
|
r"""
|
|
|
|
|
The loss function proposed by Kaiming team in their paper ``Focal Loss for Dense Object Detection`` improves the
|
|
|
|
@ -1089,11 +1095,14 @@ class FocalLoss(_Loss):
|
|
|
|
|
self.squeeze = P.Squeeze(axis=1)
|
|
|
|
|
self.tile = P.Tile()
|
|
|
|
|
self.cast = P.Cast()
|
|
|
|
|
self.dtype = P.DType()
|
|
|
|
|
self.logsoftmax = nn.LogSoftmax(1)
|
|
|
|
|
|
|
|
|
|
def construct(self, predict, target):
|
|
|
|
|
targets = target
|
|
|
|
|
_check_ndim(predict.ndim, targets.ndim)
|
|
|
|
|
_check_channel_and_shape(predict.shape[1], targets.shape[1])
|
|
|
|
|
_check_input_dtype(self.dtype(targets), self.cls_name)
|
|
|
|
|
|
|
|
|
|
if predict.ndim > 2:
|
|
|
|
|
predict = predict.view(predict.shape[0], predict.shape[1], -1)
|
|
|
|
@ -1102,7 +1111,7 @@ class FocalLoss(_Loss):
|
|
|
|
|
predict = self.expand_dims(predict, 2)
|
|
|
|
|
targets = self.expand_dims(targets, 2)
|
|
|
|
|
|
|
|
|
|
log_probability = nn.LogSoftmax(1)(predict)
|
|
|
|
|
log_probability = self.logsoftmax(predict)
|
|
|
|
|
|
|
|
|
|
if target.shape[1] == 1:
|
|
|
|
|
log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32))
|
|
|
|
@ -1116,7 +1125,7 @@ class FocalLoss(_Loss):
|
|
|
|
|
if target.shape[1] == 1:
|
|
|
|
|
convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32))
|
|
|
|
|
convert_weight = self.squeeze(convert_weight)
|
|
|
|
|
probability = log_probability * convert_weight
|
|
|
|
|
log_probability = log_probability * convert_weight
|
|
|
|
|
|
|
|
|
|
weight = F.pows(-probability + 1.0, self.gamma)
|
|
|
|
|
if target.shape[1] == 1:
|
|
|
|
|