diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 76cea197ba..cb4c96ecc4 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -285,3 +285,4 @@ from .mod import _mod_tbe from .max_pool_grad_grad import _max_pool_grad_grad_tbe from .max_pool_grad_grad_with_argmax import _max_pool_grad_grad_with_argmax_tbe from .population_count import _population_count_tbe +from .parallel_concat import _parallel_concat_tbe diff --git a/mindspore/ops/_op_impl/tbe/parallel_concat.py b/mindspore/ops/_op_impl/tbe/parallel_concat.py new file mode 100644 index 0000000000..46d8736fab --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/parallel_concat.py @@ -0,0 +1,80 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""ParallelConcat op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType


def _build_parallel_concat_op_info():
    """
    Assemble the TBE registration info for the ParallelConcat kernel.

    The supported (input, output) combinations are generated instead of being spelled out one by
    one. The registration order is kept: for every data type the `Default` and `5HD` formats come
    first, then, in a second pass over the data types, the `NHWC` and `NCHW` formats. Input and
    output always use the same data type and format.
    """
    dtype_prefixes = ("BOOL", "I8", "U8", "I16", "U16", "I32", "U32", "I64", "U64", "F16", "F32")
    format_passes = (("Default", "5HD"), ("NHWC", "NCHW"))

    reg = TBERegOp("ParallelConcat")
    reg = reg.fusion_type("OPAQUE")
    reg = reg.async_flag(False)
    reg = reg.binfile_name("parallel_concat.so")
    reg = reg.compute_cost(10)
    reg = reg.kernel_name("parallel_concat")
    reg = reg.partial_flag(True)
    reg = reg.attr("shape", "required", "listInt", "all")
    reg = reg.attr("N", "required", "int", "all")
    reg = reg.input(0, "values", False, "dynamic", "all")
    reg = reg.output(0, "output_data", False, "required", "all")

    for formats in format_passes:
        for prefix in dtype_prefixes:
            for fmt in formats:
                io_type = getattr(DataType, f"{prefix}_{fmt}")
                reg = reg.dtype_format(io_type, io_type)

    return reg.get_op_info()


parallel_concat_op_info = _build_parallel_concat_op_info()


@op_info_register(parallel_concat_op_info)
def _parallel_concat_tbe():
    """Register the ParallelConcat op info with the TBE backend."""
    return
class ParallelConcat(PrimitiveWithInfer):
    r"""
    Concatenates tensors along the first dimension.

    All input tensors must have the same shape and data type, and each of them must have size 1
    in the first dimension. The first dimension of the output equals the number of input tensors;
    the remaining dimensions are the same as those of the inputs.

    Note:
        The input tensors are all required to have size 1 in the first dimension.

    Inputs:
        - **values** (tuple, list) - Tuple or list of input tensors.

    Outputs:
        Tensor, data type same as `values`.

    Examples:
        >>> data1 = Tensor(np.array([[0, 1]]).astype(np.int32))
        >>> data2 = Tensor(np.array([[2, 1]]).astype(np.int32))
        >>> op = P.ParallelConcat()
        >>> output = op((data1, data2))
    """

    @prim_attr_register
    def __init__(self):
        """init ParallelConcat"""

    def __infer__(self, values):
        """
        Validate the inputs and infer the output shape and data type.

        Args:
            values (dict): Holds the per-input shapes under 'shape' and the per-input data types
                under 'dtype'.

        Returns:
            dict, with the inferred 'shape', the 'dtype' of the first input, and 'value' set to None.
        """
        x_shp = values['shape']
        x_type = values['dtype']

        validator.check_integer('x_shp length', len(x_shp), 1, Rel.GE, self.name)

        first_elem = x_shp[0]
        for i, elem in enumerate(x_shp):
            # A rank-0 input has no first dimension; report it through the validator instead of
            # failing on `elem[0]` below.
            validator.check_integer(f'x_shp[{i}] rank', len(elem), 1, Rel.GE, self.name)
            # Checked for every input, including the first one, so that a single-input call
            # cannot bypass the size-1 requirement.
            validator.check_integer(f'x_shp[{i}][0]', elem[0], 1, Rel.EQ, self.name)
            if i > 0:
                validator.check('x_shp[0] shape', first_elem, f'x_shp[{i}] shape', elem, Rel.EQ, self.name)

        # Every input takes part in the dtype check, the first one included: the output dtype is
        # taken from x_type[0], so it has to agree with all the others.
        args = {f'x_type[{i}]': elem for i, elem in enumerate(x_type)}
        validator.check_tensor_type_same(args, mstype.number_type + (mstype.bool_,), self.name)

        # Work on a copy so the input shape is not modified in place.
        ret_shp = list(first_elem)
        ret_shp[0] = len(x_shp)
        self.add_prim_attr('shape', ret_shp)
        self.add_prim_attr('N', len(x_shp))

        out = {'shape': ret_shp,
               'dtype': x_type[0],
               'value': None}
        return out
class ParallelConcatNet(nn.Cell):
    """Test network that feeds its two inputs to `P.ParallelConcat` as a tuple."""

    def __init__(self):
        super().__init__()
        self.parallel_concat = P.ParallelConcat()

    def construct(self, x1, x2):
        """Apply ParallelConcat to the pair `(x1, x2)` and return the result."""
        inputs = (x1, x2)
        return self.parallel_concat(inputs)