From fd1a4726b77a8218a0c32895a5017a874a3456b1 Mon Sep 17 00:00:00 2001
From: Jiaqi
Date: Mon, 22 Feb 2021 16:27:55 +0800
Subject: [PATCH] add focal loss

---
 mindspore/nn/loss/__init__.py   |   4 +-
 mindspore/nn/loss/loss.py       | 109 ++++++++++++++++++++++++++++++++
 mindspore/ops/functional.py     |  12 ++++
 tests/ut/python/nn/test_loss.py |  44 +++++++++++++
 4 files changed, 167 insertions(+), 2 deletions(-)

diff --git a/mindspore/nn/loss/__init__.py b/mindspore/nn/loss/__init__.py
index c64f96c301..df5f6d26f1 100644
--- a/mindspore/nn/loss/__init__.py
+++ b/mindspore/nn/loss/__init__.py
@@ -19,11 +19,11 @@ Cells of loss function. Loss function in machine learning is the target of the m
 It shows how well the model works on a dataset and the optimization target which the optimizer is searching.
 """
-from .loss import L1Loss, MSELoss, SmoothL1Loss, \
+from .loss import L1Loss, MSELoss, SmoothL1Loss, FocalLoss, \
     SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
     SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss
 
-__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
+__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss', 'FocalLoss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss', 'CosineEmbeddingLoss',
            'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss']
diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index 629f351c95..1465488d13 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -13,11 +13,13 @@
 # limitations under the License.
 # ============================================================================
 """loss"""
+import mindspore
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore import nn
 from mindspore.ops.primitive import constexpr
 from mindspore.ops import _selected_ops
 from mindspore.nn.cell import Cell
@@ -897,3 +899,110 @@ class BCEWithLogitsLoss(_Loss):
         pos_weight = ones_input
         loss = self.bce_with_logits_loss(predict, target, weight, pos_weight)
         return loss
+
+
+@constexpr
+def _check_ndim(predict_ndim, target_ndim):
+    validator.check_int(predict_ndim, target_ndim, Rel.EQ, 'predict_ndim', 'target_ndim')
+
+
+@constexpr
+def _check_channel_and_shape(target, predict):
+    if target not in (predict, 1):
+        raise ValueError("The channel of target must be 1 or the same as predict.")
+
+
+@constexpr
+def _check_predict_channel(predict):
+    if predict == 1:
+        raise NotImplementedError("Single channel prediction is not supported.")
+
+
+class FocalLoss(_Loss):
+    r"""
+    The focal loss, proposed by Lin et al. in the paper ``Focal Loss for Dense Object Detection``, improves the
+    effect of image object detection. It is a loss function that addresses class imbalance and the difference in
+    classification difficulty by down-weighting well-classified examples.
+
+    Args:
+        gamma (float): Gamma is used to adjust the steepness of the weight curve in focal loss. Default: 2.0.
+        weight (Union[Tensor, None]): A rescaling weight applied to the loss of each batch element. If None, no
+            weights are applied. Default: None.
+        reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and
+            "none". If "none", do not perform reduction. Default: "mean".
+
+    Inputs:
+        - **predict** (Tensor) - Input logits. Tensor of shape BCH[WD], where C is the number of classes
+          and is greater than 1.
+        - **target** (Tensor) - Tensor of shape B1H[WD] or BCH[WD]. If the target shape is B1H[WD], the
+          expected target of this loss should be the class index in the range of [0, C-1], where C is the
+          number of classes.
+
+    Outputs:
+        Tensor, the computed focal loss; a scalar unless ``reduction`` is "none".
+
+    Raises:
+        TypeError: If the data type of ``gamma`` is not float.
+        TypeError: If ``weight`` is not a Tensor.
+        ValueError: If ``target`` does not have the same dimension as ``predict``.
+        ValueError: If the channel of ``target`` is neither 1 nor the same as ``predict``.
+        ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
+
+    Examples:
+        >>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+        >>> target = Tensor([[1], [1], [0]], mstype.int32)
+        >>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
+        >>> output = focalloss(predict, target)
+        >>> print(output)
+        0.33365273
+    """
+
+    def __init__(self, weight=None, gamma=2.0, reduction='mean'):
+        super(FocalLoss, self).__init__(reduction=reduction)
+
+        self.gamma = validator.check_value_type("gamma", gamma, [float])
+        if weight is not None and not isinstance(weight, Tensor):
+            raise TypeError("The type of weight should be Tensor, but got {}.".format(type(weight)))
+        self.weight = weight
+        self.expand_dims = P.ExpandDims()
+        self.gather_d = P.GatherD()
+        self.squeeze = P.Squeeze(axis=1)
+        self.tile = P.Tile()
+        self.cast = P.Cast()
+
+    def construct(self, predict, target):
+        targets = target
+        _check_ndim(predict.ndim, targets.ndim)
+        _check_channel_and_shape(targets.shape[1], predict.shape[1])
+        _check_predict_channel(predict.shape[1])
+
+        if predict.ndim > 2:  # flatten the trailing spatial dims so both inputs are (B, C, N)
+            predict = predict.view(predict.shape[0], predict.shape[1], -1)
+            targets = targets.view(targets.shape[0], targets.shape[1], -1)
+        else:
+            predict = self.expand_dims(predict, 2)
+            targets = self.expand_dims(targets, 2)
+
+        log_probability = nn.LogSoftmax(1)(predict)  # log-probabilities over the class axis
+
+        if target.shape[1] == 1:  # index targets: keep the log-probability of the labeled class
+            log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32))
+            log_probability = self.squeeze(log_probability)
+
+        probability = F.exp(log_probability)
+
+        if self.weight is not None:
+            convert_weight = self.weight[None, :, None]
+            convert_weight = self.tile(convert_weight, (targets.shape[0], 1, targets.shape[2]))
+            if target.shape[1] == 1:
+                convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32))
+                convert_weight = self.squeeze(convert_weight)
+            probability = log_probability * convert_weight  # apply the per-class rescaling weight
+
+        weight = F.pows(-probability + 1.0, self.gamma)  # modulating factor (1 - p)^gamma
+        if target.shape[1] == 1:
+            loss = (-weight * log_probability).mean(axis=1)
+        else:
+            loss = (-weight * targets * log_probability).mean(axis=-1)
+
+        return self.get_loss(loss)
diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py
index 21d6c906c7..28e996d683 100644
--- a/mindspore/ops/functional.py
+++ b/mindspore/ops/functional.py
@@ -52,18 +52,30 @@ geswitch = P.GeSwitch()
 addn = P.AddN()
 absolute = P.Abs()
 tensor_add = P.Add()
+add = tensor_add
 neg_tensor = P.Neg()
 tensor_lt = P.Less()
+less = tensor_lt
 tensor_le = P.LessEqual()
+le = tensor_le
 tensor_gt = P.Greater()
+gt = tensor_gt
 tensor_ge = P.GreaterEqual()
+ge = tensor_ge
 tensor_sub = P.Sub()
+sub = tensor_sub
 tensor_mul = P.Mul()
+mul = tensor_mul
 tensor_div = P.RealDiv()
+div = tensor_div
 tensor_floordiv = P.FloorDiv()
+floordiv = tensor_floordiv
 tensor_pow = P.Pow()
+pows = tensor_pow
 tensor_mod = P.FloorMod()
+floormod = tensor_mod
 tensor_exp = P.Exp()
+exp = tensor_exp
 tensor_expm1 = P.Expm1()
 strided_slice = P.StridedSlice()
 same_type_shape = P.SameTypeShape()
diff --git a/tests/ut/python/nn/test_loss.py b/tests/ut/python/nn/test_loss.py
index c19f79c164..747353991f 100644
--- a/tests/ut/python/nn/test_loss.py
+++ b/tests/ut/python/nn/test_loss.py
@@ -91,6 +91,50 @@ def test_cosine_embedding_loss():
     loss(x1, x2, label)
 
 
+def test_focal_loss():
+    """ test_FocalLoss: forward computation with default arguments """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    focalloss = nn.FocalLoss()
+    focalloss(x1, x2)
+
+
+def test_focal_loss_gamma():
+    """ test_FocalLoss: a non-float gamma raises TypeError """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    with pytest.raises(TypeError):
+        focalloss = nn.FocalLoss(weight=None, gamma="mmm", reduction='mean')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_weight():
+    """ test_FocalLoss: a non-Tensor weight raises TypeError """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1]], mstype.int32)
+    with pytest.raises(TypeError):
+        focalloss = nn.FocalLoss(weight='a', gamma=2.0, reduction='mean')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_reduction():
+    """ test_FocalLoss: an unknown reduction mode raises ValueError """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    with pytest.raises(ValueError):
+        focalloss = nn.FocalLoss(weight=None, gamma=2.0, reduction='m')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_input():
+    """ test_FocalLoss: a target batch that mismatches predict raises ValueError """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1]], mstype.int32)
+    focalloss = nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
+    with pytest.raises(ValueError):
+        focalloss(x1, x2)
+
+
 def test_dice_loss():
     """ test_dice_loss """
     loss = nn.DiceLoss()
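
For reference, a minimal usage sketch of the `nn.FocalLoss` cell this patch adds. This is a sketch, assuming the patch is applied to a MindSpore build; the tensor values reuse the docstring example and are illustrative only:

    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor

    # Logits for a batch of 3 examples over C=2 classes (shape BC), and
    # class-index targets of shape B1; values are illustrative only.
    predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
    target = Tensor([[1], [1], [0]], mstype.int32)

    # gamma > 0 down-weights well-classified examples via the focal
    # term -(1 - p)^gamma * log(p); reduction='mean' returns a scalar.
    focalloss = nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
    output = focalloss(predict, target)
    print(output)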