!14094 Add validation and fix the example code in the comments; the example code and its printed output were wrong.

From: @lijiaqi0612
Reviewed-by: 
Signed-off-by:
pull/14094/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit d52da91124

@@ -436,7 +436,7 @@ class DiceLoss(_Loss):
         >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]), mstype.float32)
         >>> output = loss(y_pred, y)
         >>> print(output)
-        0.38596618
+        [0.38596618]
     """
     def __init__(self, smooth=1e-5):
         super(DiceLoss, self).__init__()
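
The corrected value is easy to sanity-check outside MindSpore. A minimal NumPy sketch of the smooth Dice loss, 1 - (2*intersection + smooth) / (union + smooth), using the y_pred defined just above the visible hunk in the full docstring (assumed here):

    import numpy as np

    # Re-derivation of the docstring value; a sketch, not the MindSpore kernel.
    y_pred = np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]], np.float32)
    y = np.array([[0, 1], [1, 0], [0, 1]], np.float32)
    smooth = 1e-5

    intersection = np.sum(y_pred * y)                  # 1.4
    union = np.sum(y_pred * y_pred) + np.sum(y * y)    # 1.56 + 3.0
    print(1.0 - (2.0 * intersection + smooth) / (union + smooth))  # ~0.38596618
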
@@ -1027,6 +1027,12 @@ def _check_channel_and_shape(predict, target):
                          f"inferred from 'predict': C={predict}.")
+@constexpr
+def _check_input_dtype(targets_dtype, cls_name):
+    validator.check_type_name("targets", targets_dtype, [mstype.int32, mstype.int64, mstype.float16,
+                                                         mstype.float32], cls_name)
 class FocalLoss(_Loss):
     r"""
     The loss function proposed by Kaiming team in their paper ``Focal Loss for Dense Object Detection`` improves the
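
The guard added in this hunk boils down to a membership test on the dtype of the target tensor. A plain-Python analogue for readers without the MindSpore validator module at hand (the helper name below is hypothetical; the real check goes through validator.check_type_name):

    # Illustrative stand-in for validator.check_type_name; not MindSpore code.
    def check_targets_dtype(targets_dtype, cls_name):
        valid = ("Int32", "Int64", "Float16", "Float32")
        if targets_dtype not in valid:
            raise TypeError(f"For '{cls_name}', the dtype of 'targets' must be one of "
                            f"{valid}, but got {targets_dtype}.")

    check_targets_dtype("Float32", "FocalLoss")  # passes silently
    check_targets_dtype("Int8", "FocalLoss")     # raises TypeError
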
@@ -1089,11 +1095,14 @@ class FocalLoss(_Loss):
         self.squeeze = P.Squeeze(axis=1)
         self.tile = P.Tile()
         self.cast = P.Cast()
+        self.dtype = P.DType()
+        self.logsoftmax = nn.LogSoftmax(1)
     def construct(self, predict, target):
         targets = target
         _check_ndim(predict.ndim, targets.ndim)
         _check_channel_and_shape(predict.shape[1], targets.shape[1])
+        _check_input_dtype(self.dtype(targets), self.cls_name)
         if predict.ndim > 2:
             predict = predict.view(predict.shape[0], predict.shape[1], -1)
@@ -1102,7 +1111,7 @@ class FocalLoss(_Loss):
             predict = self.expand_dims(predict, 2)
             targets = self.expand_dims(targets, 2)
-        log_probability = nn.LogSoftmax(1)(predict)
+        log_probability = self.logsoftmax(predict)
         if target.shape[1] == 1:
             log_probability = self.gather_d(log_probability, 1, self.cast(targets, mindspore.int32))
@@ -1116,7 +1125,7 @@ class FocalLoss(_Loss):
             if target.shape[1] == 1:
                 convert_weight = self.gather_d(convert_weight, 1, self.cast(targets, mindspore.int32))
                 convert_weight = self.squeeze(convert_weight)
-            probability = log_probability * convert_weight
+            log_probability = log_probability * convert_weight
         weight = F.pows(-probability + 1.0, self.gamma)
         if target.shape[1] == 1:
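
The rename above is the substantive fix: the class-weighted product belongs in log_probability, while probability, which feeds the modulating factor on the next line, must stay untouched. For orientation, a rough NumPy sketch of the focal term from Lin et al., -(1 - p)^gamma * log(p), simplified to the per-element case with no class weights:

    import numpy as np

    # Simplified focal term; illustrative only, not FocalLoss.construct.
    def focal_term(p, gamma=2.0):
        weight = np.power(1.0 - p, gamma)   # down-weights easy (high-p) examples
        return -weight * np.log(p)

    print(focal_term(np.array([0.9, 0.5, 0.1])))  # well-classified p=0.9 contributes least
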

@@ -24,8 +24,8 @@ class BleuScore(Metric):
     Calculates BLEU score of machine translated text with one or more references.
     Args:
-        n_gram (int): The n_gram value ranged from 1 to 4. Default: 4
-        smooth (bool): Whether or not to apply smoothing. Default: False
+        n_gram (int): The n_gram value ranged from 1 to 4. Default: 4.
+        smooth (bool): Whether or not to apply smoothing. Default: False.
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
@@ -33,7 +33,7 @@ class BleuScore(Metric):
     Example:
         >>> candidate_corpus = [['i', 'have', 'a', 'pen', 'on', 'my', 'desk']]
         >>> reference_corpus = [[['i', 'have', 'a', 'pen', 'in', 'my', 'desk'],
-        >>>                      ['there', 'is', 'a', 'pen', 'on', 'the', 'desk']]]
+        ...                      ['there', 'is', 'a', 'pen', 'on', 'the', 'desk']]]
         >>> metric = BleuScore()
        >>> metric.clear()
        >>> metric.update(candidate_corpus, reference_corpus)
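
For context, BLEU is built from clipped n-gram precisions. A minimal sketch of the unigram ingredient, scored against the first reference only (illustrative, not the metric's implementation):

    from collections import Counter

    candidate = ['i', 'have', 'a', 'pen', 'on', 'my', 'desk']
    reference = ['i', 'have', 'a', 'pen', 'in', 'my', 'desk']

    # Clip each candidate word count by its count in the reference.
    cand, ref = Counter(candidate), Counter(reference)
    clipped = sum(min(n, ref[w]) for w, n in cand.items())
    print(clipped / len(candidate))  # 6/7: every word except 'on' is matched
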

@@ -179,7 +179,7 @@ class ConfusionMatrixMetric(Metric):
         >>> y = Tensor(np.array([[[0], [1]], [[0], [1]]]))
         >>> metric.update(x, y)
         >>> x = Tensor(np.array([[[0], [1]], [[1], [0]]]))
-        >>> y = Tensor(np.array([[[0], [1]], [[1], [1]]]))
+        >>> y = Tensor(np.array([[[0], [1]], [[1], [0]]]))
         >>> avg_output = metric.eval()
         >>> print(avg_output)
         [0.75]
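
With the corrected y, both pairs in the second update are right, and the printed [0.75] is consistent with a true-positive-rate reading of the four (prediction, label) pairs, assuming the metric is configured for TPR in the constructor line outside the visible hunk. A rough check:

    import numpy as np

    # Per-sample TPR averaged over the four pairs; a simplification for checking.
    preds = np.array([[0, 1], [1, 0], [0, 1], [1, 0]])
    labels = np.array([[0, 1], [0, 1], [0, 1], [1, 0]])

    tp = np.sum((preds == 1) & (labels == 1), axis=1)
    fn = np.sum((preds == 0) & (labels == 1), axis=1)
    print(np.mean(tp / (tp + fn)))  # 0.75
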

@@ -51,7 +51,7 @@ class OcclusionSensitivity(Metric):
     Example:
         >>> class DenseNet(nn.Cell):
-        ...     def init(self):
+        ...     def __init__(self):
         ...         super(DenseNet, self).init()
         ...         w = np.array([[0.1, 0.8, 0.1, 0.1],[1, 1, 1, 1]]).astype(np.float32)
         ...         b = np.array([0.3, 0.6]).astype(np.float32)

@@ -42,11 +42,11 @@ class ROC(Metric):
         >>> metric.update(x, y)
         >>> fpr, tpr, thresholds = metric.eval()
         >>> print(fpr)
-        [0., 0., 0.33333333, 0.6666667, 1.]
+        [0. 0. 0.33333333 0.6666667 1.]
         >>> print(tpr)
-        [0., 1, 1., 1., 1.]
+        [0. 1. 1. 1. 1.]
         >>> print(thresholds)
-        [5, 4, 3, 2, 1]
+        [5 4 3 2 1]
         >>>
         >>> # 2) multiclass classification example
         >>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
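
The corrected arrays are plain NumPy prints (no commas). The binary-case numbers can be re-derived by sweeping the thresholds, assuming scores [3, 1, 4, 2] with only the 4-scoring sample positive, as set up earlier in the full docstring:

    import numpy as np

    # Re-derives the (fpr, tpr) points per threshold; not the ROC implementation.
    scores = np.array([3, 1, 4, 2])
    positive = np.array([False, False, True, False])

    for t in [5, 4, 3, 2, 1]:            # thresholds, descending
        predicted = scores >= t
        tpr = np.sum(predicted & positive) / np.sum(positive)
        fpr = np.sum(predicted & ~positive) / np.sum(~positive)
        print(t, fpr, tpr)               # fpr walks 0 -> 1 while tpr jumps to 1 at t=4
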

@@ -183,7 +183,7 @@ class DatasetHelper:
         >>> network = Net()
         >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
         >>> network = nn.WithLossCell(network, net_loss)
-        >>> train_dataset = create_custom_dataset(sparse=True)
+        >>> train_dataset = create_custom_dataset()
         >>> dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
         >>> for next_element in dataset_helper:
         ...     outputs = network(*next_element)
