From: @lijiaqi0612
Reviewed-by: 
Signed-off-by:
pull/12982/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit fe3b26a637

@ -242,7 +242,6 @@ class MAELoss(_Loss):
Raises:
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
ValueError: If the dimensions are different.
Supported Platforms:
``Ascend`` ``GPU``
@ -257,7 +256,6 @@ class MAELoss(_Loss):
"""
def construct(self, logits, label):
_check_shape(logits.shape, label.shape)
x = F.absolute(logits - label)
return self.get_loss(x)
@ -439,7 +437,7 @@ class DiceLoss(_Loss):
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
>>> output = loss(y_pred, y)
>>> print(output)
[0.7953220862819745]
[0.38596618]
"""
def __init__(self, smooth=1e-5):
super(DiceLoss, self).__init__()
@ -453,7 +451,7 @@ class DiceLoss(_Loss):
self.reduce_sum(self.mul(label.view(-1), label.view(-1)))
single_dice_coeff = (2 * intersection) / (unionset + self.smooth)
dice_loss = 1 - single_dice_coeff / label.shape[0]
dice_loss = 1 - single_dice_coeff
return dice_loss.mean()
@ -464,10 +462,17 @@ def _check_shape(logits_shape, label_shape):
@constexpr
def _check_weights(weight, label):
if weight.shape[0] != label.shape[1]:
raise ValueError("The shape of weight should be equal to the shape of label, but the shape of weight is {}, "
"and the shape of label is {}.".format(weight.shape, label.shape))
def _check_ndim_multi(logits_dim, label_dim):
if logits_dim < 2:
raise ValueError("Logits dimension should be greater than 1, but got {}".format(logits_dim))
if label_dim < 2:
raise ValueError("label dimension should be greater than 1, but got {}".format(label_dim))
@constexpr
def _check_weights(weight_shape, label_shape):
if weight_shape != label_shape:
raise ValueError("The weight shape[0] should be equal to label.shape[1].")
class MultiClassDiceLoss(_Loss):
@ -480,13 +485,13 @@ class MultiClassDiceLoss(_Loss):
weights (Union[Tensor, None]): Tensor of shape `[num_classes, dim]`.
ignore_indiex (Union[int, None]): Class index to ignore.
activation (Union[str, Cell]): Activate function applied to the output of the fully connected layer, eg. 'ReLU'.
Default: 'Softmax'. Choose from:
['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'FastGelu', 'Sigmoid',
'PReLU', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', 'LogSigmoid']
Default: 'softmax'. Choose from: ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh','Sigmoid']
Inputs:
- **y_pred** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
- **y** (Tensor) - Tensor of shape (N, ...). The data type must be float16 or float32.
- **y_pred** (Tensor) - Tensor of shape (N, C, ...). y_pred dimension should be greater than 1.
The data type must be float16 or float32.
- **y** (Tensor) - Tensor of shape (N, C, ...). y dimension should be greater than 1.
The data type must be float16 or float32.
Outputs:
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses.
@ -494,9 +499,12 @@ class MultiClassDiceLoss(_Loss):
Raises:
ValueError: If the shapes are different.
TypeError: If the type of inputs are not Tensor.
ValueError: If the dimension of y or y_pred is less than 2.
ValueError: If the weight shape[0] is not equal to y.shape[1].
ValueError: If weight is a tensor, but the dimension is not 2.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
``Ascend`` ``GPU``
Examples:
>>> loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation="softmax")
@ -504,22 +512,28 @@ class MultiClassDiceLoss(_Loss):
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
>>> output = loss(y_pred, y)
>>> print(output)
[0.7761003]
[0.3283009]
"""
def __init__(self, weights=None, ignore_indiex=None, activation="softmax"):
super(MultiClassDiceLoss, self).__init__()
activation_list = ['softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'sigmoid']
self.binarydiceloss = DiceLoss(smooth=1e-5)
self.weights = weights if weights is None else validator.check_value_type("weights", weights, [Tensor])
if isinstance(self.weights, Tensor) and self.weights.ndim != 2:
raise ValueError("The weight dim should be 2, but got {}.".format(self.weights.ndim))
self.ignore_indiex = ignore_indiex if ignore_indiex is None else \
validator.check_value_type("ignore_indiex", ignore_indiex, [int])
if isinstance(activation, str) and activation not in activation_list:
raise ValueError("The activation must be in {}, but got {}.".format(activation_list, activation))
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if self.activation is not None and not isinstance(self.activation, Cell):
raise TypeError("The activation must be str or Cell, but got {}.".format(activation))
raise TypeError("The activation must be str or Cell, but got {}.".format(type(self.activation)))
self.reshape = P.Reshape()
def construct(self, logits, label):
_check_shape(logits.shape, label.shape)
_check_ndim_multi(logits.ndim, label.ndim)
total_loss = 0
if self.activation is not None:
@ -529,7 +543,7 @@ class MultiClassDiceLoss(_Loss):
if i != self.ignore_indiex:
dice_loss = self.binarydiceloss(logits[:, i], label[:, i])
if self.weights is not None:
_check_weights(self.weights, label)
_check_weights(self.weights.shape[0], label.shape[1])
dice_loss *= self.weights[i]
total_loss += dice_loss
@ -991,7 +1005,9 @@ class BCEWithLogitsLoss(_Loss):
@constexpr
def _check_ndim(predict_nidm, target_ndim):
validator.check_int(predict_nidm, target_ndim, Rel.EQ, 'predict_nidm', 'target_ndim')
if predict_nidm != target_ndim:
raise ValueError("The dim of the predicted value and the dim of the target value must be equal, but got"
"predict dim {} and target dim {}.".format(predict_nidm, target_ndim))
@constexpr
@ -1003,7 +1019,7 @@ def _check_channel_and_shape(target, predict):
@constexpr
def _check_predict_channel(predict):
if predict == 1:
raise NotImplementedError("Single channel prediction is not supported.")
raise ValueError("Single channel prediction is not supported.")
class FocalLoss(_Loss):
@ -1032,10 +1048,13 @@ class FocalLoss(_Loss):
Raises:
TypeError: If the data type of ``gamma`` is not float..
TypeError: If ``weight`` is not a Parameter.
ValueError: If ``target`` shape different from ``predict``.
ValueError: If ``target`` dim different from ``predict``.
ValueError: If ``target`` channel is not 1 and ``target`` shape is different from ``predict``.
ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend`` ``GPU``
Example:
>>> predict = Tensor([[0.8, 1.4], [0.5, 0.9], [1.2, 0.9]], mstype.float32)
>>> target = Tensor([[1], [1], [0]], mstype.int32)

@ -84,8 +84,8 @@ class BleuScore(Metric):
Args:
inputs: Input `candidate_corpus` and ``reference_corpus`. `candidate_corpus` and `reference_corpus` are a
list. The `candidate_corpus` is an iterable of machine translated corpus. The `reference_corpus` is
an iterable of iterables of reference corpus.
list. The `candidate_corpus` is an iterable of machine translated corpus. The `reference_corpus` is
an iterable of iterables of reference corpus.
Raises:
ValueError: If the number of input is not 2.

@ -62,7 +62,7 @@ class CosineSimilarity(Metric):
Updates the internal evaluation result with 'input1'.
Args:
inputs: input_data `input1`. The input_data is a `Tensor`or an array.
inputs: input_data `input1`. The input_data is a `Tensor` or an array.
"""
input_data = self._convert_data(inputs[0])

@ -70,7 +70,7 @@ class OcclusionSensitivity(Metric):
def __init__(self, pad_val=0.0, margin=2, n_batch=128, b_box=None):
super().__init__()
self.pad_val = validator.check_value_type("pad_val", pad_val, [float])
self.margin = validator.check_value_type("margin", margin, [int, Sequence])
self.margin = validator.check_value_type("margin", margin, [int, list])
self.n_batch = validator.check_value_type("n_batch", n_batch, [int])
self.b_box = b_box if b_box is None else validator.check_value_type("b_box", b_box, [list])
self.clear()

@ -103,9 +103,9 @@ class ROC(Metric):
Update state with predictions and targets.
Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray.
In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]`
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. y contains values of integers.
In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]`
and the shape is :math:`(N, C)`, where :math:`N` is the number of cases and :math:`C`
is the number of categories. y contains values of integers.
"""
if len(inputs) != 2:
raise ValueError('ROC need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))

@ -199,7 +199,7 @@ def test_multi_class_dice_loss_init_activation():
def test_multi_class_dice_loss_init_activation2():
""" test_multi_class_dice_loss """
with pytest.raises(KeyError):
with pytest.raises(ValueError):
loss = nn.MultiClassDiceLoss(weights=None, ignore_indiex=None, activation='www')
y_pred = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]), mstype.float32)
y = Tensor(np.array([[1, 0], [0, 1]]), mstype.float32)

Loading…
Cancel
Save