confusion matrix

pull/11278/head
Jiaqi 4 years ago
parent 659b5d8e10
commit c821a2f3a2

@ -36,6 +36,7 @@ from .bleu_score import BleuScore
from .cosine_similarity import CosineSimilarity from .cosine_similarity import CosineSimilarity
from .occlusion_sensitivity import OcclusionSensitivity from .occlusion_sensitivity import OcclusionSensitivity
from .perplexity import Perplexity from .perplexity import Perplexity
from .confusion_matrix import ConfusionMatrixMetric, ConfusionMatrix
__all__ = [ __all__ = [
"names", "names",
@ -61,6 +62,8 @@ __all__ = [
"MeanSurfaceDistance", "MeanSurfaceDistance",
"RootMeanSquareDistance", "RootMeanSquareDistance",
"Perplexity", "Perplexity",
"ConfusionMatrix",
"ConfusionMatrixMetric",
] ]
__factory__ = { __factory__ = {
@ -85,6 +88,8 @@ __factory__ = {
'mean_surface_distance': MeanSurfaceDistance, 'mean_surface_distance': MeanSurfaceDistance,
'root_mean_square_distance': RootMeanSquareDistance, 'root_mean_square_distance': RootMeanSquareDistance,
'perplexity': Perplexity, 'perplexity': Perplexity,
'confusion_matrix': ConfusionMatrix,
'confusion_matrix_metric': ConfusionMatrixMetric,
} }

File diff suppressed because it is too large Load Diff

@ -21,19 +21,19 @@ from .metric import Metric
class Dice(Metric): class Dice(Metric):
r""" r"""
The Dice coefficient is a set similarity metric. It is used to calculate the similarity between two samples. The The Dice coefficient is a set similarity metric. It is used to calculate the similarity between two samples. The
value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result
is the worst. The Dice coefficient indicates the ratio of the area between two objects to the total area. is the worst. The Dice coefficient indicates the ratio of the area between two objects to the total area.
The function is shown as follows: The function is shown as follows:
.. math:: .. math::
dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true} dice = \frac{2 * (pred \bigcap true)}{pred \bigcup true}
Args: Args:
smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0. smooth (float): A term added to the denominator to improve numerical stability. Should be greater than 0.
Default: 1e-5. Default: 1e-5.
threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5. threshold (float): A threshold, which is used to compare with the input tensor. Default: 0.5.
Examples: Examples:
>>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]])) >>> x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
>>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]])) >>> y = Tensor(np.array([[0, 1], [1, 0], [0, 1]]))
>>> metric = Dice(smooth=1e-5, threshold=0.5) >>> metric = Dice(smooth=1e-5, threshold=0.5)

@ -0,0 +1,73 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# """test_confusion_matrix"""
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.nn.metrics import ConfusionMatrix
def test_confusion_matrix():
"""test_confusion_matrix"""
x = Tensor(np.array([1, 0, 1, 0]))
y = Tensor(np.array([1, 0, 0, 1]))
metric = ConfusionMatrix(num_classes=2)
metric.clear()
metric.update(x, y)
output = metric.eval()
assert np.allclose(output, np.array([[1, 1], [1, 1]]))
def test_confusion_matrix_update_len():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
metric = ConfusionMatrix(num_classes=2)
metric.clear()
with pytest.raises(ValueError):
metric.update(x)
def test_confusion_matrix_update_dim():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
y = Tensor(np.array([1, 0]))
metric = ConfusionMatrix(num_classes=2)
metric.clear()
with pytest.raises(ValueError):
metric.update(x, y)
def test_confusion_matrix_init_num_classes():
with pytest.raises(TypeError):
ConfusionMatrix(num_classes='1')
def test_confusion_matrix_init_normalize_value():
with pytest.raises(ValueError):
ConfusionMatrix(num_classes=2, normalize="wwe")
def test_confusion_matrix_init_threshold():
with pytest.raises(TypeError):
ConfusionMatrix(num_classes=2, normalize='no_norm', threshold=1)
def test_confusion_matrix_runtime():
metric = ConfusionMatrix(num_classes=2)
metric.clear()
with pytest.raises(RuntimeError):
metric.eval()

@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# """test_confusion_matrix_metric"""
import numpy as np
import pytest
from mindspore import Tensor
from mindspore.nn.metrics import ConfusionMatrixMetric
def test_confusion_matrix_metric():
"""test_confusion_matrix_metric"""
metric = ConfusionMatrixMetric(skip_channel=True, metric_name="tpr", calculation_method=False)
metric.clear()
x = Tensor(np.array([[[0], [1]], [[1], [0]]]))
y = Tensor(np.array([[[0], [1]], [[0], [1]]]))
metric.update(x, y)
x = Tensor(np.array([[[0], [1]], [[1], [0]]]))
y = Tensor(np.array([[[0], [1]], [[1], [0]]]))
metric.update(x, y)
output = metric.eval()
assert np.allclose(output, np.array([0.75]))
def test_confusion_matrix_metric_update_len():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
metric = ConfusionMatrixMetric(skip_channel=True, metric_name="ppv", calculation_method=True)
metric.clear()
with pytest.raises(ValueError):
metric.update(x)
def test_confusion_matrix_metric_update_dim():
x = Tensor(np.array([[0.2, 0.5, 0.7], [0.3, 0.1, 0.2], [0.9, 0.6, 0.5]]))
y = Tensor(np.array([1, 0]))
metric = ConfusionMatrixMetric(skip_channel=True, metric_name="tnr", calculation_method=True)
metric.clear()
with pytest.raises(ValueError):
metric.update(y, x)
def test_confusion_matrix_metric_init_skip_channel():
with pytest.raises(TypeError):
ConfusionMatrixMetric(skip_channel=1)
def test_confusion_matrix_metric_init_compute_sample():
with pytest.raises(TypeError):
ConfusionMatrixMetric(calculation_method=1)
def test_confusion_matrix_metric_init_metric_name_type():
with pytest.raises(TypeError):
metric = ConfusionMatrixMetric(skip_channel=True, metric_name=1, calculation_method=False)
x = Tensor(np.array([[[0], [1]], [[1], [0]]]))
y = Tensor(np.array([[[0], [1]], [[1], [0]]]))
metric.update(x, y)
output = metric.eval()
assert np.allclose(output, np.array([0.75]))
def test_confusion_matrix_metric_init_metric_name_str():
with pytest.raises(NotImplementedError):
metric = ConfusionMatrixMetric(skip_channel=True, metric_name="wwwww", calculation_method=False)
x = Tensor(np.array([[[0], [1]], [[1], [0]]]))
y = Tensor(np.array([[[0], [1]], [[1], [0]]]))
metric.update(x, y)
output = metric.eval()
assert np.allclose(output, np.array([0.75]))
def test_confusion_matrix_metric_runtime():
metric = ConfusionMatrixMetric(skip_channel=True, metric_name="tnr", calculation_method=True)
metric.clear()
with pytest.raises(RuntimeError):
metric.eval()
Loading…
Cancel
Save