diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index a415564b08..c1ccbccc0d 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -242,7 +242,6 @@ class MAELoss(_Loss): Raises: ValueError: If `reduction` is not one of 'none', 'mean', 'sum'. - ValueError: If the dimensions are different. Supported Platforms: ``Ascend`` ``GPU`` @@ -257,7 +256,6 @@ class MAELoss(_Loss): """ def construct(self, logits, label): - _check_shape(logits.shape, label.shape) x = F.absolute(logits - label) return self.get_loss(x) @@ -439,7 +437,7 @@ class DiceLoss(_Loss): >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32) >>> output = loss(y_pred, y) >>> print(output) - [0.7953220862819745] + [0.38596618] """ def __init__(self, smooth=1e-5): super(DiceLoss, self).__init__() @@ -453,7 +451,7 @@ class DiceLoss(_Loss): self.reduce_sum(self.mul(label.view(-1), label.view(-1))) single_dice_coeff = (2 * intersection) / (unionset + self.smooth) - dice_loss = 1 - single_dice_coeff / label.shape[0] + dice_loss = 1 - single_dice_coeff return dice_loss.mean() @@ -464,10 +462,17 @@ def _check_shape(logits_shape, label_shape): @constexpr -def _check_weights(weight, label): - if weight.shape[0] != label.shape[1]: - raise ValueError("The shape of weight should be equal to the shape of label, but the shape of weight is {}, " - "and the shape of label is {}.".format(weight.shape, label.shape)) +def _check_ndim_multi(logits_dim, label_dim): + if logits_dim < 2: + raise ValueError("Logits dimension should be greater than 1, but got {}".format(logits_dim)) + if label_dim < 2: + raise ValueError("label dimension should be greater than 1, but got {}".format(label_dim)) + + +@constexpr +def _check_weights(weight_shape, label_shape): + if weight_shape != label_shape: + raise ValueError("The weight shape[0] should be equal to label.shape[1].") class MultiClassDiceLoss(_Loss): @@ -480,13 +485,13 @@ class MultiClassDiceLoss(_Loss): weights (Union[Tensor, None]): Tensor of shape `[num_classes, dim]`. ignore_indiex (Union[int, None]): Class index to ignore. activation (Union[str, Cell]): Activate function applied to the output of the fully connected layer, eg. 'ReLU'. - Default: 'Softmax'. Choose from: - ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'FastGelu', 'Sigmoid', - 'PReLU', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', 'LogSigmoid'] + Default: 'softmax'. Choose from: ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh','Sigmoid'] Inputs: - - **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32. - - **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32. + - **y_pred** (Tensor) - Tensor of shape (N, C, ...). y_pred dimension should be greater than 1. + The data type must be float16 or float32. + - **y** (Tensor) - Tensor of shape (N, C, ...). y dimension should be greater than 1. + The data type must be float16 or float32. Outputs: Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses. @@ -494,9 +499,12 @@ class MultiClassDiceLoss(_Loss): Raises: ValueError: If the shapes are different. TypeError: If the type of inputs are not Tensor. + ValueError: If the dimension of y or y_pred is less than 2. + ValueError: If the weight shape[0] is not equal to y.shape[1]. + ValueError: If weight is a tensor, but the dimension is not 2. Supported Platforms: - ``Ascend`` ``GPU`` ``CPU`` + ``Ascend`` ``GPU`` Examples: >>> loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax") @@ -504,22 +512,28 @@ class MultiClassDiceLoss(_Loss): >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32) >>> output = loss(y_pred, y) >>> print(output) - [0.7761003] + [0.3283009] """ def __init__(self, weights=None, ignore_indiex=None, activation="softmax"): super(MultiClassDiceLoss, self).__init__() + activation_list = ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'sigmoid'] self.binarydiceloss = DiceLoss(smooth=1e-5) self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor]) + if isinstance(self.weights, Tensor) and self.weights.ndim != 2: + raise ValueError("The weight dim should be 2, but got {}.".format(self.weights.ndim)) self.ignore_indiex = ignore_indiex if ignore_indiex is None else \ validator.check_value_type("ignore_indiex", ignore_indiex, [int]) + if isinstance(activation, str) and activation not in activation_list: + raise ValueError("The activation must be in {}, but got {}.".format(activation_list, activation)) + self.activation = get_activation(activation) if isinstance(activation, str) else activation if self.activation is not None and not isinstance(self.activation, Cell): - raise TypeError("The activation must be str or Cell, but got {}.".format(activation)) + raise TypeError("The activation must be str or Cell, but got {}.".format(type(self.activation))) self.reshape = P.Reshape() def construct(self, logits, label): - _check_shape(logits.shape, label.shape) + _check_ndim_multi(logits.ndim, label.ndim) total_loss = 0 if self.activation is not None: @@ -529,7 +543,7 @@ class MultiClassDiceLoss(_Loss): if i != self.ignore_indiex: dice_loss = self.binarydiceloss(logits[:, i], label[:, i]) if self.weights is not None: - _check_weights(self.weights, label) + _check_weights(self.weights.shape[0], label.shape[1]) dice_loss *= self.weights[i] total_loss += dice_loss @@ -991,7 +1005,9 @@ class BCEWithLogitsLoss(_Loss): @constexpr def _check_ndim(predict_nidm, target_ndim): - validator.check_int(predict_nidm, target_ndim, Rel.EQ, 'predict_nidm', 'target_ndim') + if predict_nidm != target_ndim: + raise ValueError("The dim of the predicted value and the dim of the target value must be equal, but got" + "predict dim {} and target dim {}.".format(predict_nidm, target_ndim)) @constexpr @@ -1003,7 +1019,7 @@ def _check_channel_and_shape(target, predict): @constexpr def _check_predict_channel(predict): if predict == 1: - raise NotImplementedError("Single channel prediction is not supported.") + raise ValueError("Single channel prediction is not supported.") class FocalLoss(_Loss): @@ -1032,10 +1048,13 @@ class FocalLoss(_Loss): Raises: TypeError: If the data type of ``gamma`` is not float.. TypeError: If ``weight`` is not a Parameter. - ValueError: If ``target`` shape different from ``predict``. + ValueError: If ``target`` dim different from ``predict``. ValueError: If ``target`` channel is not 1 and ``target`` shape is different from ``predict``. ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'. + Supported Platforms: + ``Ascend`` ``GPU`` + Example: >>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32) >>> target = Tensor([[1], [1], [0]], mstype.int32) diff --git a/mindspore/nn/metrics/bleu_score.py b/mindspore/nn/metrics/bleu_score.py index 58b8b381b8..5a4a73474f 100644 --- a/mindspore/nn/metrics/bleu_score.py +++ b/mindspore/nn/metrics/bleu_score.py @@ -84,8 +84,8 @@ class BleuScore(Metric): Args: inputs: Input `candidate_corpus` and ``reference_corpus`. `candidate_corpus` and `reference_corpus` are a - list. The `candidate_corpus` is an iterable of machine translated corpus. The `reference_corpus` is - an iterable of iterables of reference corpus. + list. The `candidate_corpus` is an iterable of machine translated corpus. The `reference_corpus` is + an iterable of iterables of reference corpus. Raises: ValueError: If the number of input is not 2. diff --git a/mindspore/nn/metrics/cosine_similarity.py b/mindspore/nn/metrics/cosine_similarity.py index 71af0ecfff..4e3c911cd8 100644 --- a/mindspore/nn/metrics/cosine_similarity.py +++ b/mindspore/nn/metrics/cosine_similarity.py @@ -62,7 +62,7 @@ class CosineSimilarity(Metric): Updates the internal evaluation result with 'input1'. Args: - inputs: input_data `input1`. The input_data is a `Tensor`or an array. + inputs: input_data `input1`. The input_data is a `Tensor` or an array. """ input_data = self._convert_data(inputs[0]) diff --git a/mindspore/nn/metrics/occlusion_sensitivity.py b/mindspore/nn/metrics/occlusion_sensitivity.py index e69e749bd1..be1a7ced4b 100644 --- a/mindspore/nn/metrics/occlusion_sensitivity.py +++ b/mindspore/nn/metrics/occlusion_sensitivity.py @@ -70,7 +70,7 @@ class OcclusionSensitivity(Metric): def __init__(self, pad_val=0.0, margin=2, n_batch=128, b_box=None): super().__init__() self.pad_val = validator.check_value_type("pad_val", pad_val, [float]) - self.margin = validator.check_value_type("margin", margin, [int, Sequence]) + self.margin = validator.check_value_type("margin", margin, [int, list]) self.n_batch = validator.check_value_type("n_batch", n_batch, [int]) self.b_box = b_box if b_box is None else validator.check_value_type("b_box", b_box, [list]) self.clear() diff --git a/mindspore/nn/metrics/roc.py b/mindspore/nn/metrics/roc.py index 89860401f2..2b871faf14 100644 --- a/mindspore/nn/metrics/roc.py +++ b/mindspore/nn/metrics/roc.py @@ -103,9 +103,9 @@ class ROC(Metric): Update state with predictions and targets. Args: inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray. - In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]` - and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` - is the number of categories. y contains values of integers. + In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]` + and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C` + is the number of categories. y contains values of integers. """ if len(inputs) != 2: raise ValueError('ROC need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) diff --git a/tests/ut/python/nn/test_loss.py b/tests/ut/python/nn/test_loss.py index 042019591a..0bf1e9ef6b 100644 --- a/tests/ut/python/nn/test_loss.py +++ b/tests/ut/python/nn/test_loss.py @@ -199,7 +199,7 @@ def test_multi_class_dice_loss_init_activation(): def test_multi_class_dice_loss_init_activation2(): """ test_multi_class_dice_loss """ - with pytest.raises(KeyError): + with pytest.raises(ValueError): loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation='www') y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32) y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)