add focal loss

pull/12320/head
Jiaqi 4 years ago
parent 9890812cce
commit fd1a4726b7

@@ -19,11 +19,11 @@ Cells of loss function. Loss function in machine learning is the target of the model.
 It shows how well the model works on a dataset and the optimization target which the optimizer is searching.
 """
-from .loss import L1Loss, MSELoss, SmoothL1Loss, \
+from .loss import L1Loss, MSELoss, SmoothL1Loss, FocalLoss, \
     SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
     SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss

-__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
+__all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss', 'FocalLoss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
            'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss']
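With FocalLoss re-exported here, downstream code can import it like any other loss cell. A minimal usage sketch, assuming a MindSpore build that includes this commit:

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor

predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)  # logits, shape (B, C)
target = Tensor([[1], [1], [0]], mstype.int32)                          # class indices, shape (B, 1)
loss = nn.FocalLoss(gamma=2.0, reduction='mean')
print(loss(predict, target))  # scalar mean focal loss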

@@ -13,11 +13,13 @@
 # limitations under the License.
 # ============================================================================
 """loss"""
+import mindspore
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
+from mindspore import nn
 from mindspore.ops.primitive import constexpr
 from mindspore.ops import _selected_ops
 from mindspore.nn.cell import Cell
@@ -897,3 +899,110 @@ class BCEWithLogitsLoss(_Loss):
             pos_weight = ones_input
         loss = self.bce_with_logits_loss(predict, target, weight, pos_weight)
         return loss
+
+
+@constexpr
+def _check_ndim(predict_ndim, target_ndim):
+    validator.check_int(predict_ndim, target_ndim, Rel.EQ, 'predict_ndim', 'target_ndim')
+
+
+@constexpr
+def _check_channel_and_shape(target, predict):
+    if target not in (predict, 1):
+        raise ValueError("The channel of target must be 1 or the same as that of predict.")
+
+
+@constexpr
+def _check_predict_channel(predict):
+    if predict == 1:
+        raise NotImplementedError("Single channel prediction is not supported.")
+
+
+class FocalLoss(_Loss):
+    r"""
+    The loss function proposed by Kaiming He's team in the paper ``Focal Loss for Dense Object Detection``,
+    which improves image object detection. It is a loss function that addresses class imbalance and
+    differences in classification difficulty.
+
+    Args:
+        gamma (float): Adjusts the steepness of the weight curve in focal loss. Default: 2.0.
+        weight (Union[Tensor, None]): A rescaling weight applied to the loss of each batch element. If None,
+            no weights are applied. Default: None.
+        reduction (str): Type of reduction to be applied to the loss. The optional values are "mean", "sum",
+            and "none". If "none", no reduction is performed. Default: "mean".
+
+    Inputs:
+        - **predict** (Tensor) - Input logits of shape BCH[WD], where C is the number of classes and is
+          greater than 1.
+        - **target** (Tensor) - Tensor of shape B1H[WD] or BCH[WD]. If the target shape is B1H[WD], the
+          expected target of this loss is the class index within the range [0, C-1], where C is the number
+          of classes.
+
+    Outputs:
+        Tensor, the per-example focal losses, reduced according to ``reduction``.
+
+    Raises:
+        TypeError: If the data type of ``gamma`` is not float.
+        TypeError: If ``weight`` is not a Tensor.
+        ValueError: If the number of dimensions of ``target`` differs from that of ``predict``.
+        ValueError: If the channel of ``target`` is not 1 and the shape of ``target`` is different from
+            that of ``predict``.
+        ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
+
+    Example:
+        >>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+        >>> target = Tensor([[1], [1], [0]], mstype.int32)
+        >>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
+        >>> output = focalloss(predict, target)
+        >>> print(output)
+        0.33365273
+    """
+
+    def __init__(self, weight=None, gamma=2.0, reduction='mean'):
+        super(FocalLoss, self).__init__(reduction=reduction)
+        self.gamma = validator.check_value_type("gamma", gamma, [float])
+        if weight is not None and not isinstance(weight, Tensor):
+            raise TypeError("The type of weight should be Tensor, but got {}.".format(type(weight)))
+        self.weight = weight
+        self.expand_dims = P.ExpandDims()
+        self.gather_d = P.GatherD()
+        self.squeeze = P.Squeeze(axis=1)
+        self.tile = P.Tile()
+        self.cast = P.Cast()
+
+    def construct(self, predict, target):
+        targets = target
+        _check_ndim(predict.ndim, targets.ndim)
+        _check_channel_and_shape(targets.shape[1], predict.shape[1])
+        _check_predict_channel(predict.shape[1])
+
+        # Flatten the spatial dimensions: (B, C, H[, W, D]) -> (B, C, N).
+        if predict.ndim > 2:
+            predict = predict.view(predict.shape[0], predict.shape[1], -1)
+            targets = targets.view(targets.shape[0], targets.shape[1], -1)
+        else:
+            predict = self.expand_dims(predict, 2)
+            targets = self.expand_dims(targets, 2)
+
+        log_probability = nn.LogSoftmax(1)(predict)
+
+        # If targets hold class indices, gather the log-probability of the target class.
+        if target.shape[1] == 1:
+            log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32))
+            log_probability = self.squeeze(log_probability)
+
+        probability = F.exp(log_probability)
+
+        if self.weight is not None:
+            convert_weight = self.weight[None, :, None]
+            convert_weight = self.tile(convert_weight, (targets.shape[0], 1, targets.shape[2]))
+            if target.shape[1] == 1:
+                convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32))
+                convert_weight = self.squeeze(convert_weight)
+            # Rescale the log-probability by the per-class weight.
+            log_probability = log_probability * convert_weight
+
+        # The modulating factor (1 - p)^gamma down-weights well-classified examples.
+        weight = F.pows(-probability + 1.0, self.gamma)
+        if target.shape[1] == 1:
+            loss = (-weight * log_probability).mean(axis=1)
+        else:
+            loss = (-weight * targets * log_probability).mean(axis=-1)
+
+        return self.get_loss(loss)
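For reference, the quantity construct computes follows the focal-loss definition of Lin et al. With p_t the softmax probability assigned to the target class, alpha_t the optional per-class rescaling weight, and gamma the focusing parameter:

\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log(p_t)

Setting gamma = 0 recovers ordinary (weighted) cross-entropy; a larger gamma shrinks the loss contribution of well-classified examples so training focuses on hard ones.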

@@ -52,18 +52,30 @@ geswitch = P.GeSwitch()
 addn = P.AddN()
 absolute = P.Abs()
 tensor_add = P.Add()
+add = tensor_add
 neg_tensor = P.Neg()
 tensor_lt = P.Less()
+less = tensor_lt
 tensor_le = P.LessEqual()
+le = tensor_le
 tensor_gt = P.Greater()
+gt = tensor_gt
 tensor_ge = P.GreaterEqual()
+ge = tensor_ge
 tensor_sub = P.Sub()
+sub = tensor_sub
 tensor_mul = P.Mul()
+mul = tensor_mul
 tensor_div = P.RealDiv()
+div = tensor_div
 tensor_floordiv = P.FloorDiv()
+floordiv = tensor_floordiv
 tensor_pow = P.Pow()
+pows = tensor_pow
 tensor_mod = P.FloorMod()
+floormod = tensor_mod
 tensor_exp = P.Exp()
+exp = tensor_exp
 tensor_expm1 = P.Expm1()
 strided_slice = P.StridedSlice()
 same_type_shape = P.SameTypeShape()
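These short aliases are what back the F.exp and F.pows calls in FocalLoss.construct above; the others follow the same pattern for the remaining operators. A quick behavioral sketch, assuming a working MindSpore install with this commit:

import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import functional as F

x = Tensor([1.0, 2.0, 3.0], mstype.float32)
print(F.pows(x, 2.0))  # elementwise power -> [1. 4. 9.]
print(F.exp(x))        # elementwise exponential
print(F.sub(x, x))     # elementwise subtraction -> [0. 0. 0.]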

@@ -91,6 +91,50 @@ def test_cosine_embedding_loss():
     loss(x1, x2, label)
+
+
+def test_focal_loss():
+    """ test_FocalLoss with default arguments """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    focalloss = nn.FocalLoss()
+    focalloss(x1, x2)
+
+
+def test_focal_loss_gamma():
+    """ test_FocalLoss with a non-float gamma """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    with pytest.raises(TypeError):
+        focalloss = nn.FocalLoss(weight=None, gamma="mmm", reduction='mean')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_weight():
+    """ test_FocalLoss with a non-Tensor weight """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1]], mstype.int32)
+    with pytest.raises(TypeError):
+        focalloss = nn.FocalLoss(weight='a', gamma=2.0, reduction='mean')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_reduction():
+    """ test_FocalLoss with an invalid reduction """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1], [1], [0]], mstype.int32)
+    with pytest.raises(ValueError):
+        focalloss = nn.FocalLoss(weight=None, gamma=2.0, reduction='m')
+        focalloss(x1, x2)
+
+
+def test_focal_loss_input():
+    """ test_FocalLoss with mismatched predict and target shapes """
+    x1 = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
+    x2 = Tensor([[1]], mstype.int32)
+    focalloss = nn.FocalLoss(weight=None, gamma=2.0, reduction='mean')
+    with pytest.raises(ValueError):
+        focalloss(x1, x2)
+
+
 def test_dice_loss():
     """ test_dice_loss """
     loss = nn.DiceLoss()
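The focal-loss tests above only exercise class-index targets (shape B1). The second branch in construct also accepts a dense BCH[WD] target; a hypothetical one-hot call under that contract, not part of this commit's tests:

import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Tensor

predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
onehot = Tensor([[0.0, 1.0], [0.0, 1.0], [1.0, 0.0]], mstype.float32)  # shape (B, C) matches predict
focalloss = nn.FocalLoss(gamma=2.0, reduction='mean')
print(focalloss(predict, onehot))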
