!13456 multdiceloss api

From: @lijiaqi0612
Reviewed-by: @zh_qh,@kisnwang
Signed-off-by: @kisnwang
pull/13456/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit fd811b4dd0

@ -490,9 +490,9 @@ class MultiClassDiceLoss(_Loss):
Inputs:
- **y_pred** (Tensor) - Tensor of shape (N, C, ...). The y_pred dimension should be greater than 1. The data
type must be float16 or float32.
type must be float16 or float32.
- **y** (Tensor) - Tensor of shape (N, C, ...). The y dimension should be greater than 1. The data type must be
float16 or float32.
loat16 or float32.
Outputs:
Tensor, a tensor of shape with the per-example sampled MultiClass Dice Losses.

@ -31,7 +31,7 @@ class ROC(Metric):
range [0,num_classes-1]. Default: None.
Examples:
>>> 1) binary classification example
>>> # 1) binary classification example
>>> x = Tensor(np.array([3, 1, 4, 2]))
>>> y = Tensor(np.array([0, 1, 2, 3]))
>>> metric = ROC(pos_label=2)
@ -42,7 +42,7 @@ class ROC(Metric):
[0., 1, 1., 1., 1.]
[5, 4, 3, 2, 1]
>>>
>>> 2) multiclass classification example
>>> # 2) multiclass classification example
>>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
... [0.05, 0.05, 0.05, 0.75]]))
>>> y = Tensor(np.array([0, 1, 2, 3]))
@ -101,6 +101,7 @@ class ROC(Metric):
def update(self, *inputs):
"""
Update state with predictions and targets.
Args:
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray.
In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]`

Loading…
Cancel
Save