diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 12e95cf781..10947a535a 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -237,3 +237,4 @@ from .basic_lstm_cell import _basic_lstm_cell_tbe from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe +from .confusion_matrix import _confusion_matrix_tbe diff --git a/mindspore/ops/_op_impl/tbe/confusion_matrix.py b/mindspore/ops/_op_impl/tbe/confusion_matrix.py new file mode 100644 index 0000000000..28dd17f23f --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/confusion_matrix.py @@ -0,0 +1,63 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ConfusionMatrix op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +confusion_matrix_op_info = TBERegOp("ConfusionMatrix") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("confusion_matrix.so") \ + .compute_cost(10) \ + .kernel_name("confusion_matrix") \ + .partial_flag(True) \ + .attr("num_classes", "required", "int", "all") \ + .attr("dtype", "required", "str", "all") \ + .input(0, "labels", False, "required", "all") \ + .input(1, "predictions", False, "required", "all") \ + .input(2, "weights", False, "optional", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.U8_Default, DataType.U8_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(confusion_matrix_op_info) +def _confusion_matrix_tbe(): + """ConfusionMatrix TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 0b9df232ee..996df7c285 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -73,7 +73,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl, ApplyProximalAdagrad, SparseApplyProximalAdagrad, ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell) -from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey, CheckBprop +from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, + CheckValid, MakeRefKey, CheckBprop, ConfusionMatrix) from . import _quant_ops from ._quant_ops import * from .thor_ops import * @@ -287,7 +288,8 @@ __all__ = [ "BesselI1e", "Atan", "Atanh", - "BasicLSTMCell" + "BasicLSTMCell", + "ConfusionMatrix" ] __all__.extend(_quant_ops.__all__) diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index d73f53eb6a..714cad681f 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -366,3 +366,50 @@ class CheckBprop(PrimitiveWithInfer): raise TypeError(f"{tips}, the dtype of {i}th output should be {ydtype}," f" but got {xdtype}.") return xdtypes + + +class ConfusionMatrix(PrimitiveWithInfer): + r""" + Calculate the confusion matrix from labels and predictions. + + Args: + num_classes (int): The num of classes. + dtype (str): Data type of confusion matrix. Default: 'int32'. + + Inputs: + - **labels** (Tensor) - real labels, tensor of 1-D. the dtype must be non-negative Integer. + - **predictions** (Tensor) - the labels from prediction, tensor of 1-D. + the shape same as `labels` and the dtype must be non-negative Integer. + - **weights** (Tensor) - tensor of 1-D. the shape same as `predictions`. + + Outputs: + Tensor, the confusion matrix, with shape (`num_classes`, `num_classes`). + + Examples: + >>> confusion_matrix = P.ConfusionMatrix(4) + >>> labels = Tensor([0, 1, 1, 3], mindspore.int32) + >>> predictions = Tensor([1, 2, 1, 3], mindspore.int32) + >>> confusion_matrix(labels, predictions) + """ + + @prim_attr_register + def __init__(self, num_classes, dtype="int32"): + validator.check_value_type("num_classes", num_classes, [int], self.name) + validator.check_value_type("dtype", dtype, [str], self.name) + + def infer_shape(self, labels, predictions, weights=None): + validator.check('labels dimension', len(labels), '', 1, Rel.EQ, self.name) + validator.check('labels shape', labels, 'predictions shape', predictions, Rel.EQ, self.name) + if weights is not None: + validator.check('labels shape', labels, 'weights shape', weights, Rel.EQ, self.name) + ret = (self.num_classes, self.num_classes) + return ret + + def infer_dtype(self, labels, predictions, weights=None): + validator.check_subclass('labels', labels, mstype.tensor, self.name) + validator.check_subclass('predictions', predictions, mstype.tensor, self.name) + if weights is not None: + validator.check_subclass('weights', weights, mstype.tensor, self.name) + args = {"labels": labels, "predictions": predictions} + validator.check_tensor_type_same(args, (mstype.number_type), self.name) + return labels diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index bf1d8b72d3..de29220ea6 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -285,6 +285,16 @@ class SpaceToBatchNDNet(Cell): def construct(self, x): return self.space_to_batch_nd(x) + +class ConfusionMatrixNet(Cell): + def __init__(self): + super(ConfusionMatrixNet, self).__init__() + self.confusion_matrix = P.ConfusionMatrix(4, "int32") + + def construct(self, x, y): + return self.confusion_matrix(x, y) + + test_case_array_ops = [ ('CustNet1', { 'block': CustNet1(), @@ -325,6 +335,9 @@ test_case_array_ops = [ ('BatchToSpaceNDNet', { 'block': BatchToSpaceNDNet(), 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), + ('ConfusionMatrixNet', { + 'block': ConfusionMatrixNet(), + 'desc_inputs': [Tensor([0, 1, 1, 3], ms.int32), Tensor([0, 1, 1, 3], ms.int32)]}), ] test_case_lists = [test_case_array_ops]