diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 110fb966bb..e466d478de 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -80,6 +80,8 @@ static std::map tbe_func_adapter_map = { {"concat", "concat_d"}, {"slice", "slice_d"}, {"reduce_sum", "reduce_sum_d"}, + {"inplace_add", "inplace_add_d"}, + {"inplace_sub", "inplace_sub_d"}, {"one_hot", "one_hot_d"}, {"sum", "reduce_sum_d"}, {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"}, diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index b41167274f..f5ec529877 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -171,6 +171,8 @@ const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); +const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); +const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); // NN const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index 82c6f4f243..291c4a92d5 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -180,6 +180,8 @@ extern const PrimitivePtr kPrimLessEqual; extern const PrimitivePtr kPrimCumSum; extern const PrimitivePtr kPrimCumProd; extern const PrimitivePtr kPrimSubscalar; +extern const PrimitivePtr kPrimInplaceAdd; +extern const PrimitivePtr kPrimInplaceSub; // NN extern const PrimitivePtr kPrimFlatten; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 797c9c6b67..f337dca329 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -133,6 +133,8 @@ constexpr auto kResizeNearestNeighborV2OpName = "ResizeNearestNeighborV2"; constexpr auto kResizeNearestNeighborV2GradOpName = "ResizeNearestNeighborV2Grad"; constexpr auto kApplyRMSPropOpname = "ApplyRMSProp"; constexpr auto kCumsumOpName = "Cumsum"; +constexpr auto kInplaceAddOpName = "InplaceAdd"; +constexpr auto kInplaceSubOpName = "InplaceSub"; constexpr auto kResizeBilinearV2OpName = "kResizeBilinearV2"; constexpr auto kReduceProdOpName = "ReduceProd"; constexpr auto kCumprodOpName = "Cumprod"; diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index aaa810fcb0..1cfbffb934 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -15,6 +15,8 @@ """tbe ops""" from .abs import _abs_tbe +from .inplace_add import _inplace_add_tbe +from .inplace_sub import _inplace_sub_tbe from .abs_grad import _abs_grad_tbe from .acos import _acos_tbe from .acos_grad import _acos_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/inplace_add.py b/mindspore/ops/_op_impl/tbe/inplace_add.py new file mode 100644 index 0000000000..9a14fc9a63 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/inplace_add.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""InplaceAdd op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +inplace_add_op_info = TBERegOp("InplaceAdd") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("inplace_add_d.so") \ + .compute_cost(10) \ + .kernel_name("inplace_add_d") \ + .partial_flag(True) \ + .attr("indices", "required", "listInt", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "v", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(inplace_add_op_info) +def _inplace_add_tbe(): + """InplaceAdd TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/inplace_sub.py b/mindspore/ops/_op_impl/tbe/inplace_sub.py new file mode 100644 index 0000000000..07f59e05fc --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/inplace_sub.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""InplaceSub op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +inplace_sub_op_info = TBERegOp("InplaceSub") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("inplace_sub_d.so") \ + .compute_cost(10) \ + .kernel_name("inplace_sub_d") \ + .partial_flag(True) \ + .attr("indices", "required", "listInt", "all") \ + .input(0, "x", False, "required", "all") \ + .input(1, "v", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(inplace_sub_op_info) +def _inplace_sub_tbe(): + """InplaceSub TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 00c7db7cec..bc3301030b 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -41,7 +41,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge from .inner_ops import ScalarCast from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, BitwiseAnd, BitwiseOr, - BitwiseXor, Inv, Invert, ApproximateEqual, + BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil, Acosh, Greater, GreaterEqual, Less, LessEqual, Log, Log1p, LogicalAnd, @@ -178,6 +178,8 @@ __all__ = [ 'DropoutGrad', 'Dropout', 'Neg', + 'InplaceAdd', + 'InplaceSub', 'Slice', 'DType', 'NPUAllocFloatStatus', diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index d37e4d6331..5f25cdd8e4 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -772,6 +772,125 @@ class Neg(PrimitiveWithInfer): return input_x +class InplaceAdd(PrimitiveWithInfer): + """ + Adds v into specified rows of x. Computes y = x; y[i,] += v. + + Args: + - **indices** (Union[int, tuple]) - Indices into the left-most dimension of x, and determines which rows of x + to add with v. It is a int or tuple, whose value is in [0, the first dimension size of x). + + Inputs: + - **input_x** (Tensor) - The first input is a tensor whose data type is number. + - **input_v** (Tensor) - The second input is a tensor who has the same dimension sizes as x except + the first dimension, which must be the same as indices's size. + + Outputs: + Tensor, has the same shape and dtype as input. + + Examples: + + >>> indices = [0, 1] + >>> input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32) + >>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32) + >>> inplaceAdd = P.InplaceAdd(indices) + >>> inplaceAdd(input_x, input_v) + [[1.5 3.] + [4. 5.5] + [5. 6.]] + """ + + @prim_attr_register + def __init__(self, indices): + """init InplaceAdd""" + self.init_prim_io_names(inputs=['x', 'v'], outputs=['y']) + self.indices = indices + + def infer_shape(self, x_shape, v_shape): + validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name) + if isinstance(self.indices, int): + validator.check("size of indices", 1, "v's first dimension", v_shape[0], + Rel.EQ, self.name) + if self.indices < 0 or self.indices >= x_shape[0]: + raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {self.indices}.') + else: + validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0], + Rel.EQ, self.name) + for i in self.indices: + if i < 0 or i >= x_shape[0]: + raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') + if len(x_shape) > 1: + validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:], + Rel.EQ, self.name) + return x_shape + + def infer_dtype(self, x_dtype, v_dtype): + args = {'x': x_dtype, 'v': v_dtype} + valid_type = [mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_type, self.name) + validator.check_value_type('indices', self.indices, [tuple, int], self.name) + return x_dtype + + +class InplaceSub(PrimitiveWithInfer): + """ + Subtracts v into specified rows of x. Computes y = x; y[i, :] -= v; return y. + + Args: + - **indices** (Union[int, tuple]) - Indices into the left-most dimension of x, and determines which rows of x + to sub with v. It is a int or tuple, whose value is in [0, the first dimension size of x). + + Inputs: + - **input_x** (Tensor) - The first input is a tensor whose data type is number. + - **input_v** (Tensor) - The second input is a tensor who has the same dimension sizes as x except + the first dimension, which must be the same as indices's size. + + Outputs: + Tensor, has the same shape and dtype as input. + + Examples: + >>> indices = [0, 1] + >>> input_x = Tensor(np.array([[1, 2], [3, 4], [5, 6]]), mindspore.float32) + >>> input_v = Tensor(np.array([[0.5, 1.0], [1.0, 1.5]]), mindspore.float32) + >>> inplaceSub = P.InplaceSub(indices) + >>> inplaceSub(input_x, input_v) + [[0.5 1.] + [2. 2.5] + [5. 6.]] + """ + + @prim_attr_register + def __init__(self, indices): + """init InplaceSub""" + self.init_prim_io_names(inputs=['x', 'v'], outputs=['y']) + self.indices = indices + + def infer_shape(self, x_shape, v_shape): + validator.check("x", len(x_shape), "v", len(v_shape), Rel.EQ, self.name) + if isinstance(self.indices, int): + validator.check("size of indices", 1, "v's first dimension", v_shape[0], + Rel.EQ, self.name) + if self.indices < 0 or self.indices >= x_shape[0]: + raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {self.indices}.') + else: + validator.check("size of indices", len(self.indices), "v's first dimension", v_shape[0], + Rel.EQ, self.name) + for i in self.indices: + if i < 0 or i >= x_shape[0]: + raise ValueError(f'The value of indices must be in [0, {x_shape[0]}), but got {i}.') + if len(x_shape) > 1: + validator.check("x's ith dimension", x_shape[1:], "v's ith dimension", v_shape[1:], + Rel.EQ, self.name) + return x_shape + + def infer_dtype(self, x_dtype, v_dtype): + args = {'x': x_dtype, 'v': v_dtype} + valid_type = [mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_type, self.name) + validator.check_value_type('indices', self.indices, [tuple, int], self.name) + return x_dtype + + class Sub(_MathBinaryOp): """ Subtracts the second input tensor from the first input tensor element-wise. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 500bafe9ff..79d7de5d7d 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -368,6 +368,26 @@ class ApplyRMSNet(nn.Cell): return out +class InplaceAddNet(nn.Cell): + def __init__(self): + super(InplaceAddNet, self).__init__() + self.inplace_add = P.InplaceAdd(indices=(0, 1)) + + def construct(self, x, v): + out = self.inplace_add(x, v) + return out + + +class InplaceSubNet(nn.Cell): + def __init__(self): + super(InplaceSubNet, self).__init__() + self.inplace_sub = P.InplaceSub(indices=(0, 1)) + + def construct(self, x, v): + out = self.inplace_sub(x, v) + return out + + test_case_math_ops = [ ('BitwiseAnd', { 'block': P.BitwiseAnd(), @@ -493,6 +513,16 @@ test_case_math_ops = [ 'desc_inputs': [[2, 512, 56, 56]], 'desc_bprop': [[2, 512, 56, 56]], 'skip': ['backward']}), + ('InplaceAdd', { + 'block': InplaceAddNet(), + 'desc_inputs': [Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32)), + Tensor(np.array([[0.5, 1], [1, 1.5]]).astype(np.float32))], + 'skip': ['backward']}), + ('InplaceSub', { + 'block': InplaceSubNet(), + 'desc_inputs': [Tensor(np.array([[1, 2], [3, 4], [5, 6]]).astype(np.float32)), + Tensor(np.array([[0.5, 1], [1, 1.5]]).astype(np.float32))], + 'skip': ['backward']}), ('ACos', { 'block': P.ACos(), 'desc_inputs': [Tensor(np.array([2., 3.]).astype(np.float32))],