|
|
|
@ -297,6 +297,67 @@ def _check_label_dtype(labels_dtype, cls_name):
|
|
|
|
|
validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DiceLoss(_Loss):
|
|
|
|
|
r"""
|
|
|
|
|
The Dice coefficient is a set similarity loss. It is used to calculate the similarity between two samples. The
|
|
|
|
|
value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result
|
|
|
|
|
is the worst. The Dice coefficient indicates the ratio of the area between two objects to the total area.
|
|
|
|
|
The function is shown as follows:
|
|
|
|
|
|
|
|
|
|
.. math::
|
|
|
|
|
dice = 1 - \frac{2 * (pred \bigcap true)}{pred \bigcup true}
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
|
|
|
|
|
Default: 1e-5.
|
|
|
|
|
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **y_pred** (Tensor) - Tensor of shape (N, C).
|
|
|
|
|
- **y** (Tensor) - Tensor of shape (N, C).
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, a tensor of shape with the per-example sampled Dice losses.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
|
``Ascend``
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> loss = nn.Diceloss(smooth=1e-5, threshold=0.5)
|
|
|
|
|
>>> y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
|
|
|
|
|
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
|
|
|
|
|
>>> output = loss(y_pred, y)
|
|
|
|
|
>>> print(output)
|
|
|
|
|
[0.77777076]
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, smooth=1e-5, threshold=0.5):
|
|
|
|
|
super(DiceLoss, self).__init__()
|
|
|
|
|
self.smooth = validator.check_positive_float(smooth, "smooth")
|
|
|
|
|
self.threshold = validator.check_value_type("threshold", threshold, [float])
|
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
|
|
|
|
|
|
def construct(self, logits, label):
|
|
|
|
|
_check_shape(logits.shape, label.shape)
|
|
|
|
|
logits = self.cast((logits > self.threshold), mstype.float32)
|
|
|
|
|
label = self.cast(label, mstype.float32)
|
|
|
|
|
dim = label.shape
|
|
|
|
|
pred_flat = self.reshape(logits, (dim[0], -1))
|
|
|
|
|
true_flat = self.reshape(label, (dim[0], -1))
|
|
|
|
|
|
|
|
|
|
intersection = self.reduce_sum((pred_flat * true_flat), 1)
|
|
|
|
|
unionset = self.reduce_sum(pred_flat, 1) + self.reduce_sum(true_flat, 1)
|
|
|
|
|
|
|
|
|
|
dice = (2 * intersection + self.smooth) / (unionset + self.smooth)
|
|
|
|
|
dice_loss = 1 - self.reduce_sum(dice) / dim[0]
|
|
|
|
|
|
|
|
|
|
return dice_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _check_shape(logits_shape, label_shape):
|
|
|
|
|
validator.check('logits_shape', logits_shape, 'label_shape', label_shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SampledSoftmaxLoss(_Loss):
|
|
|
|
|
r"""
|
|
|
|
|
Computes the sampled softmax training loss.
|
|
|
|
|