diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
index f6e47640e2..6daef4c98b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
@@ -21,6 +21,7 @@ namespace mindspore {
 namespace kernel {
 template <typename T>
 void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
+  node_ = kernel_node;
   CheckParam(kernel_node);
   axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS));
@@ -28,27 +29,28 @@ void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   if (axis_ < 0) {
     axis_ = axis_ + SizeToInt(input_1_shape.size());
   }
-
-  input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
-  for (size_t i = 0; i < input_num_; i++) {
-    auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
-    auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_);
-    input_flat_shape_list_.push_back(flat_shape);
-  }
 }
 
 template <typename T>
 bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                 const std::vector<kernel::AddressPtr> & /*workspace*/,
                                 const std::vector<kernel::AddressPtr> &outputs) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(node_);
+  std::vector<std::vector<size_t>> input_flat_shape_list;
+  for (size_t i = 0; i < input_num; i++) {
+    auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(node_, i);
+    auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_);
+    input_flat_shape_list.push_back(flat_shape);
+  }
+
   auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
   auto buff_size = outputs[0]->size;
   // each input's row of shape after flat are same
-  auto before_axis = input_flat_shape_list_[0][0];
+  auto before_axis = input_flat_shape_list[0][0];
   for (size_t i = 0; i < before_axis; ++i) {
-    for (size_t j = 0; j < input_num_; ++j) {
+    for (size_t j = 0; j < input_num; ++j) {
       auto input_j_addr = reinterpret_cast<T *>(inputs[j]->addr);
-      auto copy_num = input_flat_shape_list_[j][1];
+      auto copy_num = input_flat_shape_list[j][1];
       auto offset = copy_num * i;
       auto ret = memcpy_s(output_addr, buff_size, input_j_addr + offset, copy_num * sizeof(T));
       if (ret != EOK) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
index 089463ee22..3f6887cdf7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
@@ -36,8 +36,7 @@ class ConcatCPUKernel : public CPUKernel {
  private:
  void CheckParam(const CNodePtr &kernel_node);
  int axis_ = 0;
- size_t input_num_ = 1;
- std::vector<std::vector<size_t>> input_flat_shape_list_;
+ CNodePtr node_ = nullptr;
 };
 
 MS_REG_CPU_KERNEL_T(
diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc
index 88afa8c4f6..929455aa04 100644
--- a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc
@@ -140,7 +140,7 @@ void OpTilingCalculater::Init() {
   tiling_func_map_ = optiling::OpTilingRegistryInterf::RegisteredOpInterf();
   MS_LOG(INFO) << "tiling_func_map_ size:" << tiling_func_map_.size();
   for (const auto &iter : tiling_func_map_) {
-    MS_LOG(INFO) << "Regist tiling func:" << iter.first;
+    MS_LOG(INFO) << "Register tiling func:" << iter.first;
   }
 }
 
@@ -150,6 +150,7 @@ std::string GetRealOpType(const std::string &op_type) {
     {"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"},
     {"SparseGatherV2", "GatherV2"},
     {"Pad", "PadD"},
+    {"Concat", "ConcatD"},
   };
   auto iter = kOpTypeMap.find(op_type);
   if (iter == kOpTypeMap.end()) {
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index ce73041d25..497d4c00d9 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -30,6 +30,7 @@ namespace mindspore {
 // op name. Op which not exists in operator/ops.h, so define it's name here
+constexpr auto kConcatOpName = "Concat";
 constexpr auto kUniqueOpName = "Unique";
 constexpr auto kComputeAccidentalHitsOpName = "ComputeAccidentalHits";
 constexpr auto kCTCGreedyDecoderOpName = "CTCGreedyDecoder";
@@ -492,7 +493,8 @@ const std::set<std::string> kComputeDepend = {kUniqueOpName, kComputeAccidentalH
 const std::set<std::string> k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
 
 const std::set<std::string> DynamicShapeConstInputToAttr = {
-  kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName, kTransposeOpName, kReduceSumOpName};
+  kCastOpName, kExpandDimsOpName, kReshapeOpName, kEmbeddingLookupOpName,
+  kTransposeOpName, kReduceSumOpName, kConcatOpName};
 
 static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
   try {
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index c183b78a38..5454813363 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -287,6 +287,8 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
                                       const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplAddN(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list);
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index 521868e0e5..2756fc97f7 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -954,6 +954,90 @@ AbstractBasePtr InferImplSequenceMask(const AnalysisEnginePtr &, const Primitive
   return std::make_shared<AbstractTensor>(kBool, output_shape);
 }
 
+AbstractBasePtr InferImplConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                const AbstractBasePtrList &args_spec_list) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  const std::string op_name = primitive->name();
+  if (args_spec_list.empty()) {
+    MS_LOG(EXCEPTION) << "args_spec_list is empty.";
+  }
+
+  AbstractTuplePtr arg = nullptr;
+  AbstractTensorPtr tensor_base = nullptr;
+  size_t tuple_len = 0;
+  MS_EXCEPTION_IF_NULL(args_spec_list[0]);
+  if (args_spec_list[0]->isa<AbstractTuple>()) {
+    CheckArgsSize(op_name, args_spec_list, 1);
+    arg = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+    tuple_len = arg->elements().size();
+    tensor_base = CheckArg<AbstractTensor>(op_name, arg->elements(), 0);
+  } else if (args_spec_list[0]->isa<AbstractTensor>()) {
+    tuple_len = args_spec_list.size();
+    tensor_base = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  }
+
+  MS_EXCEPTION_IF_NULL(tensor_base);
+  ShapeVector shape_base = tensor_base->shape()->shape();
+  int64_t rank_base = SizeToLong(shape_base.size());
+  ShapeVector min_shape_base = tensor_base->shape()->min_shape();
+  ShapeVector max_shape_base = tensor_base->shape()->max_shape();
+  (void)CheckMinMaxShape(shape_base, &min_shape_base, &max_shape_base);
+
+  primitive->set_attr("T", tensor_base->element()->BuildType());
+  primitive->set_attr("inputNums", MakeValue(SizeToLong(tuple_len)));
+
+  ValuePtr axis = primitive->GetAttr("axis");
+  // Axis value should be in [-(rank_base + 1), rank_base).
+  int64_t axis_value = CheckAxis(op_name, axis, -(rank_base + 1), rank_base);
+  // If axis is negative, add offset(rank_base) to turn it to positive.
+  axis_value = GetPositiveAxis(axis_value, LongToSize(rank_base));
+
+  int64_t all_shp = shape_base[axis_value];
+  int64_t min_all_shp = min_shape_base[axis_value];
+  int64_t max_all_shp = max_shape_base[axis_value];
+  for (size_t i = 1; i < tuple_len; ++i) {
+    AbstractTensorPtr tensor = nullptr;
+    if (args_spec_list[0]->isa<AbstractTuple>()) {
+      tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
+    } else if (args_spec_list[0]->isa<AbstractTensor>()) {
+      tensor = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
+    }
+    ShapeVector shape_tensor = tensor->shape()->shape();
+    int64_t rank_tensor = SizeToLong(shape_tensor.size());
+    ShapeVector min_shape_tensor = tensor->shape()->min_shape();
+    ShapeVector max_shape_tensor = tensor->shape()->max_shape();
+    (void)CheckMinMaxShape(shape_tensor, &min_shape_tensor, &max_shape_tensor);
+    (void)CheckDtypeSame(op_name, tensor_base, tensor);
+    if (rank_tensor != rank_base) {
+      MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Rank";
+    }
+    for (int j = 0; j < rank_base; ++j) {
+      if (j != axis_value && shape_tensor[j] != shape_base[j]) {
+        MS_LOG(EXCEPTION) << op_name << " can not concat element " << i << " with the first element: Wrong Size";
+      }
+    }
+    if (all_shp == -1 || shape_base[axis_value] == -1) {
+      all_shp = -1;
+    } else {
+      all_shp += shape_tensor[axis_value];
+    }
+    min_all_shp += min_shape_tensor[axis_value];
+    max_all_shp += max_shape_tensor[axis_value];
+  }
+
+  AbstractTensorPtr ret = dyn_cast<AbstractTensor>(tensor_base->Broaden());
+  MS_EXCEPTION_IF_NULL(ret);
+  auto shape = ret->shape()->shape();
+  auto min_shape = ret->shape()->min_shape();
+  auto max_shape = ret->shape()->max_shape();
+  (void)CheckMinMaxShape(shape, &min_shape, &max_shape);
+  shape[axis_value] = all_shp;
+  min_shape[axis_value] = min_all_shp;
+  max_shape[axis_value] = max_all_shp;
+  ret->set_shape(std::make_shared<Shape>(shape, min_shape, max_shape));
+  return ret;
+}
+
 AbstractBasePtr InferImplRange(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const AbstractBasePtrList &args_spec_list) {
   const std::string &op_name = primitive->name();
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index b739f18eca..f098014d98 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -81,6 +81,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimMapUniform, {InferImplMapUniform, true}},
     {prim::kPrimSplit, {InferImplSplit, true}},
     {prim::kPrimSequenceMask, {InferImplSequenceMask, true}},
+    {prim::kPrimConcat, {InferImplConcat, true}},
     {prim::kPrimRange, {InferImplRange, true}},
     // Structure
     {prim::kPrimMakeTuple, {InferImplMakeTuple, true}},
diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index f4a1811073..3554f711bf 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -170,6 +170,7 @@ from .minimum_ds import _minimum_ds_tbe
 from .minimum_grad import _minimum_grad_tbe
 from .maximum_grad import _maximum_grad_tbe
 from .concat import _concat_tbe
+from .concat_ds import _concat_ds_tbe
 from .slice import _slice_tbe
 from .sign import _sign_tbe
 from .greater import _greater_tbe
diff --git a/mindspore/ops/_op_impl/tbe/concat_ds.py b/mindspore/ops/_op_impl/tbe/concat_ds.py
new file mode 100644
index 0000000000..037de95107
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/concat_ds.py
@@ -0,0 +1,38 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""Concat op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+concat_ds_op_info = TBERegOp("Concat") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("concat_d.so") \
+    .compute_cost(10) \
+    .kernel_name("concat_d") \
+    .partial_flag(True) \
+    .dynamic_shape(True) \
+    .attr("axis", "required", "int", "all") \
+    .input(0, "input_values", False, "dynamic", "all") \
+    .output(0, "output_data", False, "required", "all") \
+    .op_pattern("dynamicFormat") \
+    .dtype_format(DataType.None_None, DataType.None_None) \
+    .get_op_info()
+
+
+@op_info_register(concat_ds_op_info)
+def _concat_ds_tbe():
+    """Concat TBE register"""
+    return
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index fa6366de62..2623daaed6 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -2148,6 +2148,19 @@ class Concat(PrimitiveWithInfer):
         out = {'shape': ret_shp,
                'dtype': x_type[0],
                'value': value}
+        if -1 in x_shp[0]:
+            x_min_shp = input_x['min_shape']
+            ret_min_shp = x_min_shp[0].copy()
+            ret_min_shp[axis] = 0
+            for all_min_shp in x_min_shp:
+                ret_min_shp[axis] += all_min_shp[axis]
+            out['min_shape'] = ret_min_shp
+            x_max_shp = input_x['max_shape']
+            ret_max_shp = x_max_shp[0].copy()
+            ret_max_shp[axis] = 0
+            for all_max_shp in x_max_shp:
+                ret_max_shp[axis] += all_max_shp[axis]
+            out['max_shape'] = ret_max_shp
         return out
 
 
@@ -2789,7 +2802,7 @@ class StridedSlice(PrimitiveWithInfer):
         if has_ellipsis:
             # When there is ellipsis, handle the second half of the ellipsis split.
             ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
-                len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
+                                     len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
             ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
             j += 1
             i += ellipsis_occupied_dims
@@ -3985,7 +3998,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
         offset = 1
         for i in range(len(self.block_shape)):
             padded = out_shape[i + offset] + self.paddings[i][0] + \
-                self.paddings[i][1]
+                     self.paddings[i][1]
             if padded % self.block_shape[i] != 0:
                 raise ValueError(f'For \'{self.name}\' padded[{i}] {padded} should be divisible by '
                                  f'block_shape[{i}] {self.block_shape[i]}')
diff --git a/tests/st/dynamic_shape/test_concat.py b/tests/st/dynamic_shape/test_concat.py
new file mode 100644
index 0000000000..a6a4d32933
--- /dev/null
+++ b/tests/st/dynamic_shape/test_concat.py
@@ -0,0 +1,49 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+class Net(nn.Cell):
+    def __init__(self, axis=0):
+        super(Net, self).__init__()
+        self.unique = P.Unique()
+        self.reshape = P.Reshape()
+        self.concat = P.Concat(axis=axis)
+
+    def construct(self, x1, x2):
+        out1_unique, _ = self.unique(x1)
+        out2_unique, _ = self.unique(x2)
+        out1_shape = self.reshape(out1_unique, (1, -1, 2))
+        out2_shape = self.reshape(out2_unique, (1, -1, 2))
+        return self.concat((out1_shape, out2_shape))
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_dynamic_concat():
+    x1 = Tensor(np.array([1, 2, 3, 1, 4, 2]), mstype.int32)
+    x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.int32)
+    net = Net(axis=1)
+    output = net(x1, x2)
+    expect = np.array([[[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]]])
+    assert (output.asnumpy() == expect).all()
diff --git a/tests/st/dynamic_shape/test_concat_cpu.py b/tests/st/dynamic_shape/test_concat_cpu.py
new file mode 100644
index 0000000000..4ffc6cf61b
--- /dev/null
+++ b/tests/st/dynamic_shape/test_concat_cpu.py
@@ -0,0 +1,49 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+import mindspore.common.dtype as mstype
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+
+class Net(nn.Cell):
+    def __init__(self, axis=0):
+        super(Net, self).__init__()
+        self.unique = P.Unique()
+        self.reshape = P.Reshape()
+        self.concat = P.Concat(axis=axis)
+
+    def construct(self, x1, x2):
+        out1_unique, _ = self.unique(x1)
+        out2_unique, _ = self.unique(x2)
+        out1_shape = self.reshape(out1_unique, (1, -1, 2))
+        out2_shape = self.reshape(out2_unique, (1, -1, 2))
+        return self.concat((out1_shape, out2_shape))
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_dynamic_concat_cpu():
+    x1 = Tensor(np.array([1, 2, 3, 1, 4, 2]), mstype.int32)
+    x2 = Tensor(np.array([1, 2, 3, 4, 5, 6]), mstype.int32)
+    net = Net(axis=1)
+    output = net(x1, x2)
+    expect = np.array([[[1, 2], [3, 4], [1, 2], [3, 4], [5, 6]]])
+    assert (output.asnumpy() == expect).all()
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index c7ce5b1204..e446644bdd 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -835,14 +835,14 @@ def test_mixed_precision_cast():
     assert z.dtype == mstype.float16
 
 
-def test_while_concat():
+def test_while_add():
     class Net(nn.Cell):
         def __init__(self, data):
             super(Net, self).__init__()
             self.start = Tensor(0, dtype=mstype.int32)
             self.end = Tensor(2, dtype=mstype.int32)
             self.out = Tensor(np.zeros([2, 3], dtype=np.float32))
-            self.concat = P.Concat()
+            self.add = P.TensorAdd()
 
         def construct(self, inputs):
             idx = self.start
@@ -850,7 +850,7 @@ def test_while_concat():
             out = self.out
             while idx < end:
                 xi = inputs[idx, :, :]
-                out = self.concat((out, xi))
+                out = self.add(out, xi)
                 idx = idx + 1
             return out
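
Note (not part of the patch): a minimal, self-contained Python sketch of the shape bookkeeping that the new infer paths (InferImplConcat and the min_shape/max_shape branch in Concat.__infer__) perform for dynamic shapes. The helper name and example values are illustrative only, not MindSpore APIs: every non-axis dimension is taken from the first input, and the extent along the concat axis is the sum of the inputs' extents, with -1 propagating as "unknown".

def concat_infer_extents(shapes, axis):
    # Illustrative helper: inputs must agree on every dimension except `axis`;
    # the output extent along `axis` is the sum of the inputs' extents, and a -1
    # (unknown) extent makes the result unknown, analogous to the -1 handling above.
    out = list(shapes[0])
    if any(s[axis] == -1 for s in shapes):
        out[axis] = -1
    else:
        out[axis] = sum(s[axis] for s in shapes)
    return out

# Example: two inputs of shape (1, ?, 2) concatenated along axis 1.
print(concat_infer_extents([[1, -1, 2], [1, -1, 2]], axis=1))  # [1, -1, 2]
print(concat_infer_extents([[1, 1, 2], [1, 1, 2]], axis=1))    # min shapes -> [1, 2, 2]
print(concat_infer_extents([[1, 3, 2], [1, 3, 2]], axis=1))    # max shapes -> [1, 6, 2]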