diff --git a/mindspore/nn/loss/__init__.py b/mindspore/nn/loss/__init__.py
index b0ead67f7b..1327943ff2 100644
--- a/mindspore/nn/loss/__init__.py
+++ b/mindspore/nn/loss/__init__.py
@@ -20,8 +20,8 @@ It shows how well the model works on a dataset and the optimization target which
 """
 from .loss import L1Loss, MSELoss, SmoothL1Loss, \
-    SoftmaxCrossEntropyWithLogits, CosineEmbeddingLoss
+    SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss
 
 __all__ = ['L1Loss', 'MSELoss', 'SmoothL1Loss',
-           'SoftmaxCrossEntropyWithLogits',
+           'SoftmaxCrossEntropyWithLogits', 'BCELoss',
            'CosineEmbeddingLoss']
diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py
index bebfbe0b5b..7592fb6629 100644
--- a/mindspore/nn/loss/loss.py
+++ b/mindspore/nn/loss/loss.py
@@ -262,6 +262,67 @@ class SoftmaxCrossEntropyWithLogits(_Loss):
         return self.get_loss(x)
 
 
+class BCELoss(_Loss):
+    r"""
+    BCELoss creates a criterion to measure the binary cross entropy between the true labels and predicted labels.
+
+    Note:
+        Let the predicted labels be :math:`x`, the true labels be :math:`y` and the output loss be
+        :math:`\ell(x, y)`. Then,
+
+        .. math::
+            L = \{l_1,\dots,l_N\}^\top, \quad
+            l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right],
+
+        where :math:`N` is the batch size. Then,
+
+        .. math::
+            \ell(x, y) = \begin{cases}
+            L, & \text{if reduction} = \text{`none';}\\
+            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
+            \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
+            \end{cases}
+
+    Args:
+        weight (Tensor, optional): A rescaling weight applied to the loss of each batch element.
+            It must have the same shape and data type as `inputs`. Default: None.
+        reduction (str): Specifies the reduction to be applied to the output.
+            Its value must be one of 'none', 'mean' or 'sum'. Default: 'none'.
+
+    Inputs:
+        - **inputs** (Tensor) - The input Tensor. The data type must be float16 or float32.
+        - **labels** (Tensor) - The label Tensor, which has the same shape and data type as `inputs`.
+
+    Outputs:
+        Tensor or Scalar. If `reduction` is 'none', the output is a tensor with the same shape as `inputs`.
+        Otherwise, the output is a scalar.
+
+    Examples:
+        >>> weight = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 3.3, 2.2]]), mindspore.float32)
+        >>> loss = nn.BCELoss(weight=weight, reduction='mean')
+        >>> inputs = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]), mindspore.float32)
+        >>> labels = Tensor(np.array([[0, 1, 0], [0, 0, 1]]), mindspore.float32)
+        >>> loss(inputs, labels)
+    """
+
+    def __init__(self, weight=None, reduction='none'):
+        super(BCELoss, self).__init__()
+        self.binary_cross_entropy = P.BinaryCrossEntropy(reduction=reduction)
+        self.weight_one = weight is None
+        if not self.weight_one:
+            self.weight = weight
+        else:
+            self.ones = P.OnesLike()
+
+    def construct(self, inputs, labels):
+        if self.weight_one:
+            weight = self.ones(inputs)
+        else:
+            weight = self.weight
+        loss = self.binary_cross_entropy(inputs, labels, weight)
+        return loss
+
+
 @constexpr
 def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name):
     validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name)
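One detail worth noting in construct: P.BinaryCrossEntropy always receives an explicit weight tensor, so when no weight is given, an all-ones tensor is generated with P.OnesLike(). A minimal NumPy sketch of the same formula for cross-checking the docstring example (bce_reference is an illustrative name, not part of the patch):

import numpy as np

def bce_reference(inputs, labels, weight=None, reduction='none'):
    # l_n = -w_n * (y_n * log(x_n) + (1 - y_n) * log(1 - x_n))
    if weight is None:
        weight = np.ones_like(inputs)  # mirrors the P.OnesLike() fallback in construct
    loss = -weight * (labels * np.log(inputs) + (1 - labels) * np.log(1 - inputs))
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss  # reduction == 'none'

# Tensors from the docstring example above.
inputs = np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]], np.float32)
labels = np.array([[0, 1, 0], [0, 0, 1]], np.float32)
weight = np.array([[1.0, 2.0, 3.0], [4.0, 3.3, 2.2]], np.float32)
print(bce_reference(inputs, labels, weight, reduction='mean'))

The MindSpore op should agree with this value up to float32 rounding.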
diff --git a/tests/ut/python/nn/test_loss.py b/tests/ut/python/nn/test_loss.py
index f055443dbe..f4d97ef1ac 100644
--- a/tests/ut/python/nn/test_loss.py
+++ b/tests/ut/python/nn/test_loss.py
@@ -53,6 +53,34 @@ def test_SoftmaxCrossEntropyWithLogits_reduce():
     loss(logits, labels)
 
 
+def test_BCELoss():
+    """ test_BCELoss """
+    loss = nn.BCELoss()
+
+    inputs_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]).astype(np.float32))
+    target_data = Tensor(np.array([[0, 1, 0], [0, 0, 1]]).astype(np.float32))
+    loss(inputs_data, target_data)
+
+
+def test_BCELoss_reduce():
+    """ test_BCELoss_reduce """
+    loss = nn.BCELoss(reduction='mean')
+
+    inputs_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]).astype(np.float32))
+    target_data = Tensor(np.array([[0, 1, 0], [0, 0, 1]]).astype(np.float32))
+    loss(inputs_data, target_data)
+
+
+def test_BCELoss_weight():
+    """ test_BCELoss_weight """
+    weight = Tensor(np.array([[1.0, 2.0, 3.0], [2.2, 2.6, 3.9]]).astype(np.float32))
+    loss = nn.BCELoss(weight=weight)
+
+    inputs_data = Tensor(np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]]).astype(np.float32))
+    target_data = Tensor(np.array([[0, 1, 0], [0, 0, 1]]).astype(np.float32))
+    loss(inputs_data, target_data)
+
+
 def test_cosine_embedding_loss():
     """ test CosineEmbeddingLoss """
     loss = nn.CosineEmbeddingLoss()
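The unit tests above only verify that graph construction and execution succeed; they do not compare values. A value-level check against the same formula could look like the following sketch (assuming a working MindSpore install; PyNative mode is used so the result can be read back eagerly, and the tolerance is a judgment call):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

inputs_np = np.array([[0.1, 0.2, 0.3], [0.5, 0.7, 0.9]], np.float32)
labels_np = np.array([[0, 1, 0], [0, 0, 1]], np.float32)

loss = nn.BCELoss(reduction='mean')
out = loss(Tensor(inputs_np), Tensor(labels_np))

# Unweighted reference value from the docstring formula (w_n = 1).
expected = -np.mean(labels_np * np.log(inputs_np)
                    + (1 - labels_np) * np.log(1 - inputs_np))
assert np.allclose(out.asnumpy(), expected, rtol=1e-4)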