|
|
|
@ -13,11 +13,13 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""loss"""
|
|
|
|
|
import mindspore
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
from mindspore.common.parameter import Parameter
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.ops import functional as F
|
|
|
|
|
from mindspore import nn
|
|
|
|
|
from mindspore.ops.primitive import constexpr
|
|
|
|
|
from mindspore.ops import _selected_ops
|
|
|
|
|
from mindspore.nn.cell import Cell
|
|
|
|
@ -896,3 +898,110 @@ class BCEWithLogitsLoss(_Loss):
|
|
|
|
|
pos_weight = ones_input
|
|
|
|
|
loss = self.bce_with_logits_loss(predict, target, weight, pos_weight)
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _check_ndim(predict_nidm, target_ndim):
|
|
|
|
|
validator.check_int(predict_nidm, target_ndim, Rel.EQ, 'predict_nidm', 'target_ndim')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _check_channel_and_shape(target, predict):
|
|
|
|
|
if target not in (predict, 1):
|
|
|
|
|
raise ValueError("The target must have a channel or the same shape as predict.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@constexpr
|
|
|
|
|
def _check_predict_channel(predict):
|
|
|
|
|
if predict == 1:
|
|
|
|
|
raise NotImplementedError("Single channel prediction is not supported.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FocalLoss(_Loss):
|
|
|
|
|
r"""
|
|
|
|
|
The loss function proposed by Kaiming team in their paper ``Focal Loss for Dense Object Detection`` improves the
|
|
|
|
|
effect of image object detection. It is a loss function to solve the imbalance of categories and the difference of
|
|
|
|
|
classification difficulty.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
gamma (float): Gamma is used to adjust the steepness of weight curve in focal loss. Default: 2.0.
|
|
|
|
|
weight (Union[Tensor, None]): A rescaling weight applied to the loss of each batch element. If None, no weights
|
|
|
|
|
are applied. Default: None.
|
|
|
|
|
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
|
|
|
|
|
If "none", do not perform reduction. Default: "mean".
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **predict** (Tensor) - Input logits. Tensor of shape should be BCH[WD]. Where C is the number of classes.
|
|
|
|
|
Its value is greater than 1.
|
|
|
|
|
- **target** (Tensor) - Tensor of shape should be B1H[WD] or BCH[WD]. If the target shape is B1H[WD], the
|
|
|
|
|
expected target of this loss should be the class index within the range of [0, C-1],
|
|
|
|
|
where C is the number of classes.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, a tensor of shape with the per-example sampled Focal losses.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If the data type of ``gamma`` is not float..
|
|
|
|
|
TypeError: If ``weight`` is not a Parameter.
|
|
|
|
|
ValueError: If ``target`` shape different from ``predict``.
|
|
|
|
|
ValueError: If ``target`` channel is not 1 and ``target`` shape is different from ``predict``.
|
|
|
|
|
ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
>>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
|
|
|
|
|
>>> target = Tensor([[1], [1], [0]], mstype.int32)
|
|
|
|
|
>>> focalloss = nn.FocalLoss(weight=Tensor([1, 2]), gamma=2.0, reduction='mean')
|
|
|
|
|
>>> output = focalloss(inputs, labels)
|
|
|
|
|
>>> print(output)
|
|
|
|
|
0.33365273
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, weight=None, gamma=2.0, reduction='mean'):
|
|
|
|
|
super(FocalLoss, self).__init__(reduction=reduction)
|
|
|
|
|
|
|
|
|
|
self.gamma = validator.check_value_type("gamma", gamma, [float])
|
|
|
|
|
if weight is not None and not isinstance(weight, Tensor):
|
|
|
|
|
raise TypeError("The type of weight should be Tensor, but got {}.".format(type(weight)))
|
|
|
|
|
self.weight = weight
|
|
|
|
|
self.expand_dims = P.ExpandDims()
|
|
|
|
|
self.gather_d = P.GatherD()
|
|
|
|
|
self.squeeze = P.Squeeze(axis=1)
|
|
|
|
|
self.tile = P.Tile()
|
|
|
|
|
self.cast = P.Cast()
|
|
|
|
|
|
|
|
|
|
def construct(self, predict, target):
|
|
|
|
|
targets = target
|
|
|
|
|
_check_ndim(predict.ndim, targets.ndim)
|
|
|
|
|
_check_channel_and_shape(targets.shape[1], predict.shape[1])
|
|
|
|
|
_check_predict_channel(predict.shape[1])
|
|
|
|
|
|
|
|
|
|
if predict.ndim > 2:
|
|
|
|
|
predict = predict.view(predict.shape[0], predict.shape[1], -1)
|
|
|
|
|
targets = targets.view(targets.shape[0], targets.shape[1], -1)
|
|
|
|
|
else:
|
|
|
|
|
predict = self.expand_dims(predict, 2)
|
|
|
|
|
targets = self.expand_dims(targets, 2)
|
|
|
|
|
|
|
|
|
|
log_probability = nn.LogSoftmax(1)(predict)
|
|
|
|
|
|
|
|
|
|
if target.shape[1] == 1:
|
|
|
|
|
log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32))
|
|
|
|
|
log_probability = self.squeeze(log_probability)
|
|
|
|
|
|
|
|
|
|
probability = F.exp(log_probability)
|
|
|
|
|
|
|
|
|
|
if self.weight is not None:
|
|
|
|
|
convert_weight = self.weight[None, :, None]
|
|
|
|
|
convert_weight = self.tile(convert_weight, (targets.shape[0], 1, targets.shape[2]))
|
|
|
|
|
if target.shape[1] == 1:
|
|
|
|
|
convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32))
|
|
|
|
|
convert_weight = self.squeeze(convert_weight)
|
|
|
|
|
probability = log_probability * convert_weight
|
|
|
|
|
|
|
|
|
|
weight = F.pows(-probability + 1.0, self.gamma)
|
|
|
|
|
if target.shape[1] == 1:
|
|
|
|
|
loss = (-weight * log_probability).mean(axis=1)
|
|
|
|
|
else:
|
|
|
|
|
loss = (-weight * targets * log_probability).mean(axis=-1)
|
|
|
|
|
|
|
|
|
|
return self.get_loss(loss)
|
|
|
|
|