|
|
|
@ -31,7 +31,7 @@ class ROC(Metric):
|
|
|
|
|
range [0,num_classes-1]. Default: None.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> 1) binary classification example
|
|
|
|
|
>>> # 1) binary classification example
|
|
|
|
|
>>> x = Tensor(np.array([3, 1, 4, 2]))
|
|
|
|
|
>>> y = Tensor(np.array([0, 1, 2, 3]))
|
|
|
|
|
>>> metric = ROC(pos_label=2)
|
|
|
|
@ -42,7 +42,7 @@ class ROC(Metric):
|
|
|
|
|
[0., 1, 1., 1., 1.]
|
|
|
|
|
[5, 4, 3, 2, 1]
|
|
|
|
|
>>>
|
|
|
|
|
>>> 2) multiclass classification example
|
|
|
|
|
>>> # 2) multiclass classification example
|
|
|
|
|
>>> x = Tensor(np.array([[0.28, 0.55, 0.15, 0.05], [0.10, 0.20, 0.05, 0.05], [0.20, 0.05, 0.15, 0.05],
|
|
|
|
|
... [0.05, 0.05, 0.05, 0.75]]))
|
|
|
|
|
>>> y = Tensor(np.array([0, 1, 2, 3]))
|
|
|
|
@ -101,6 +101,7 @@ class ROC(Metric):
|
|
|
|
|
def update(self, *inputs):
|
|
|
|
|
"""
|
|
|
|
|
Update state with predictions and targets.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
inputs: Input `y_pred` and `y`. `y_pred` and `y` are Tensor, list or numpy.ndarray.
|
|
|
|
|
In most cases (not strictly), y_pred is a list of floating numbers in range :math:`[0, 1]`
|
|
|
|
|