From 0660140708e9aec125fbe55f9076367da550c24f Mon Sep 17 00:00:00 2001
From: hedongdong
Date: Wed, 10 Feb 2021 14:58:39 +0800
Subject: [PATCH] add new inner operator centralization

---
 .../pass/const_input_to_attr_registry.cc          |  1 +
 mindspore/core/base/core_ops.h                    |  1 +
 mindspore/ops/_op_impl/tbe/__init__.py            |  1 +
 mindspore/ops/_op_impl/tbe/centralization.py      | 38 +++++++++++
 mindspore/ops/operations/__init__.py              |  2 +-
 mindspore/ops/operations/inner_ops.py             | 68 +++++++++++++++++++
 .../test_tbe_ops/test_centralization.py           | 47 +++++++++++++
 7 files changed, 157 insertions(+), 1 deletion(-)
 create mode 100644 mindspore/ops/_op_impl/tbe/centralization.py
 create mode 100644 tests/st/ops/ascend/test_tbe_ops/test_centralization.py

diff --git a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc
index dd7da1d3b6..66d510b1b9 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/const_input_to_attr_registry.cc
@@ -38,6 +38,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
   Register(prim::kPrimReduceMin->name(), {1});
   Register(prim::kPrimReduceSum->name(), {1});
   Register(prim::kPrimReduceMean->name(), {1});
+  Register(prim::kPrimCentralization->name(), {1});
   Register(prim::kPrimGather->name(), {2});
   Register(prim::kPrimGatherD->name(), {1});
   Register(prim::kPrimEmbeddingLookup->name(), {2, 3, 4, 5});
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index d8f6c44a64..f827faf4e8 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -350,6 +350,7 @@ inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
 inline const PrimitivePtr kPrimReduceAny = std::make_shared<Primitive>("ReduceAny");
 inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
 inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
+inline const PrimitivePtr kPrimCentralization = std::make_shared<Primitive>("Centralization");
 inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
 inline const PrimitivePtr kPrimSin = std::make_shared<Primitive>("Sin");
 inline const PrimitivePtr kPrimCos = std::make_shared<Primitive>("Cos");
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index 1e4b1baddb..18eb98ea84 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -360,3 +360,4 @@ from .nll_loss_grad import _nll_loss_grad_tbe
 from .mish import _mish_tbe
 from .mul_no_nan import _mul_no_nan_tbe
 from .selu import _selu_tbe
+from .centralization import _centralization_tbe
diff --git a/mindspore/ops/_op_impl/tbe/centralization.py b/mindspore/ops/_op_impl/tbe/centralization.py
new file mode 100644
index 0000000000..a06a4bab0b
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/centralization.py
@@ -0,0 +1,38 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Centralization op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+centralization_op_info = TBERegOp("Centralization") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("centralization.so") \
+    .compute_cost(10) \
+    .kernel_name("centralization") \
+    .partial_flag(True) \
+    .attr("axis", "required", "listInt", "all") \
+    .input(0, "x", False, "required", "all") \
+    .output(0, "y", False, "required", "all") \
+    .op_pattern("reduce") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .get_op_info()
+
+
+@op_info_register(centralization_op_info)
+def _centralization_tbe():
+    """Centralization TBE register"""
+    return
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 4936cf7f62..20bd7cfd52 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -41,7 +41,7 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
 from .control_ops import ControlDepend, GeSwitch, Merge
-from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey
+from .inner_ops import ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign, MakeRefKey, Centralization
 from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
                        BitwiseAnd, BitwiseOr,
diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py
index efd5ef3ff1..a90bfe88a1 100644
--- a/mindspore/ops/operations/inner_ops.py
+++ b/mindspore/ops/operations/inner_ops.py
@@ -21,6 +21,7 @@ from ..._checkparam import Rel
 from ...common import dtype as mstype
 from ...common.dtype import tensor, dtype_to_pytype
 from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
+from .. import signature as sig


 class ScalarCast(PrimitiveWithInfer):
@@ -357,3 +358,70 @@ class MakeRefKey(Primitive):

     def __call__(self):
         pass
+
+
+class Centralization(PrimitiveWithInfer):
+    """
+    Computes centralization of the input: y = x - mean(x, axis).
+
+    Note:
+        Each dimension index must be in the range `[-input.ndim, input.ndim)`; negative indices count from the last dimension.
+
+    Inputs:
+        - **input_x** (Tensor) - The input tensor. The data type must be float16 or float32.
+        - **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Default: (), reduce all dimensions.
+          Only constant values are allowed. Must be in the range [-rank(input_x), rank(input_x)).
+
+    Outputs:
+        Tensor, has the same shape and dtype as `input_x`.
+
+    Raises:
+        TypeError: If `axis` is not one of the following types: int, list, tuple.
+        TypeError: If `axis` has non-int elements.
+
+    Supported Platforms:
+        ``Ascend``
+
+    Examples:
+        >>> mindspore.set_seed(1)
+        >>> input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
+        >>> centralization = ops.Centralization()
+        >>> output = centralization(input_x, -1)
+        >>> print(output)
+        [[ 1.1180509 -1.1180508]
+         [ 0.2723984 -0.2723984]]
+    """
+
+    __mindspore_signature__ = (
+        sig.make_sig('input_x'),
+        sig.make_sig('axis', default=())
+    )
+
+    @prim_attr_register
+    def __init__(self):
+        """Initialize Centralization"""
+        self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['output'])
+
+    def __infer__(self, input_x, axis):
+        x_shape = list(input_x['shape'])
+        x_dtype = input_x['dtype']
+        axis_v = axis['value']
+        rank = len(x_shape)
+
+        args = {'input_x': input_x['dtype']}
+        validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name)
+
+        if axis_v is None:
+            raise ValueError(f"For {self.name}, axis must be const.")
+        validator.check_value_type('axis', axis_v, [int, list, tuple], self.name)
+
+        if isinstance(axis_v, int):
+            validator.check_int_range(axis_v, -rank, rank, Rel.INC_LEFT, 'axis', self.name)
+        elif axis_v:
+            for index, one_axis in enumerate(axis_v):
+                validator.check_value_type('axis[%d]' % index, one_axis, [int], self.name)
+
+        out = {'shape': x_shape,
+               'dtype': x_dtype,
+               'value': None}
+        return out
diff --git a/tests/st/ops/ascend/test_tbe_ops/test_centralization.py b/tests/st/ops/ascend/test_tbe_ops/test_centralization.py
new file mode 100644
index 0000000000..9012fa2910
--- /dev/null
+++ b/tests/st/ops/ascend/test_tbe_ops/test_centralization.py
@@ -0,0 +1,47 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.common.api import ms_function
+from mindspore.ops import operations as P
+
+class Net(nn.Cell):
+    def __init__(self, axis=()):
+        super(Net, self).__init__()
+        self.centralization = P.Centralization()
+        self.axis = axis
+
+    @ms_function
+    def construct(self, inputs):
+        return self.centralization(inputs, self.axis)
+
+def test_net():
+    np.random.seed(1)
+    x1 = np.random.randn(2, 2).astype(np.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    centralization = Net(-1)
+    output = centralization(Tensor(x1))
+    print(x1)
+    print(output.asnumpy())
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
+    centralization = Net(-1)
+    output = centralization(Tensor(x1))
+    print(x1)
+    print(output.asnumpy())
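--
For reference (not part of the patch): a minimal NumPy sketch of the semantics the new
operator documents, y = x - mean(x, axis). With the same seed it reproduces the axis=-1
case exercised by test_centralization.py above; the variable names are illustrative only.

import numpy as np

np.random.seed(1)
x = np.random.randn(2, 2).astype(np.float32)
# Centralization subtracts the mean taken along the chosen axes;
# keepdims=True lets the per-row mean broadcast back over x's shape.
expected = x - x.mean(axis=-1, keepdims=True)
print(expected)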