diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
index bf7179b551..f6e47640e2 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.cc
@@ -19,84 +19,50 @@
 namespace mindspore {
 namespace kernel {
-void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+template <typename T>
+void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   CheckParam(kernel_node);
-  axis_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
+  axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS));
   auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   if (axis_ < 0) {
-    axis_ = axis_ + SizeToLong(input_1_shape.size());
+    axis_ = axis_ + SizeToInt(input_1_shape.size());
   }
-  axis_ += 4 - SizeToLong(input_1_shape.size());
-  auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  for (size_t i = 0; i < input_num; i++) {
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
-    CPUKernelUtils::ExpandDimsTo4(&input_shape);
-    input_shape_list_.push_back(input_shape);
+  input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t i = 0; i < input_num_; i++) {
+    auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+    auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_);
+    input_flat_shape_list_.push_back(flat_shape);
   }
-
-  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-  CPUKernelUtils::ExpandDimsTo4(&output_shape_);
 }
 
-bool ConcatCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                             const std::vector<kernel::AddressPtr> & /*workspace*/,
-                             const std::vector<kernel::AddressPtr> &outputs) {
-  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+template <typename T>
+bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                const std::vector<kernel::AddressPtr> &outputs) {
+  auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
   auto buff_size = outputs[0]->size;
-  size_t dim0 = output_shape_[0];
-  size_t dim1 = output_shape_[1];
-  size_t dim2 = output_shape_[2];
-
-  if (axis_ == 3) {
-    for (size_t i = 0; i < dim0; ++i) {
-      for (size_t j = 0; j < dim1; ++j) {
-        for (size_t k = 0; k < dim2; ++k) {
-          CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size);
-        }
-      }
-    }
-  } else if (axis_ == 2) {
-    for (size_t i = 0; i < dim0; ++i) {
-      for (size_t j = 0; j < dim1; ++j) {
-        CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size);
+  // every input's flattened shape has the same number of rows
+  auto before_axis = input_flat_shape_list_[0][0];
+  for (size_t i = 0; i < before_axis; ++i) {
+    for (size_t j = 0; j < input_num_; ++j) {
+      auto input_j_addr = reinterpret_cast<T *>(inputs[j]->addr);
+      auto copy_num = input_flat_shape_list_[j][1];
+      auto offset = copy_num * i;
+      auto ret = memcpy_s(output_addr, buff_size, input_j_addr + offset, copy_num * sizeof(T));
+      if (ret != EOK) {
+        MS_LOG(EXCEPTION) << "memcpy failed.";
       }
+      output_addr += copy_num;
+      buff_size -= copy_num * sizeof(T);
     }
-  } else if (axis_ == 1) {
-    for (size_t i = 0; i < dim0; ++i) {
-      CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size);
-    }
-  } else if (axis_ == 0) {
-    CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size);
   }
   return true;
 }
 
-void ConcatCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
-                                       size_t dim2, float **output_addr, size_t *buff_size) {
-  for (size_t i = 0; i < input_shape_list_.size(); ++i) {
-    auto input_i_shape = input_shape_list_[i];
-    auto input_i_addr = reinterpret_cast<float *>(inputs[i]->addr);
-
-    size_t num = CPUKernelUtils::GetElementNumOnAxis(input_i_shape, axis_);
-    num *= input_i_shape[axis_];
-    auto pos = CPUKernelUtils::CalcOffset(input_i_shape, dim0, dim1, dim2, 0);
-    auto ret = memcpy_s(*output_addr, *buff_size, input_i_addr + pos, num * sizeof(float));
-    if (ret != EOK) {
-      MS_LOG(EXCEPTION) << "memcpy failed.";
-    }
-    *output_addr += num;
-    *buff_size -= num * sizeof(float);
-  }
-}
-
-void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) {
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  if (input_shape.size() > 4) {
-    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but ConcatCPUKernel only supports 4d or lower.";
-  }
-
+template <typename T>
+void ConcatCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
   size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
   if (output_num != 1) {
     MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output.";
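The rewrite above replaces the hard-coded 4-D loops with a shape-agnostic scheme: each input is viewed as a 2-D (rows, cols) buffer, where rows is the product of the dimensions before the concat axis and cols is the product of the axis dimension and everything after it; the output is built by interleaving one row-sized memcpy per input. A minimal NumPy model of that copy loop, assuming a non-negative axis (which InitKernel guarantees by normalizing negative values); `flat_concat` and its names are illustrative, not part of the patch:

```python
# Minimal NumPy model of the rewritten Launch loop; an illustration,
# not the kernel code itself.
import numpy as np

def flat_concat(inputs, axis):
    out_shape = list(inputs[0].shape)
    out_shape[axis] = sum(x.shape[axis] for x in inputs)
    # view every input as (rows before axis, cols from axis on)
    flats = [x.reshape(int(np.prod(x.shape[:axis], dtype=np.int64)), -1)
             for x in inputs]
    rows = flats[0].shape[0]  # identical for every input by Concat's contract
    chunks = []
    for i in range(rows):        # outer loop over the shared rows
        for f in flats:          # inner loop appends one row per input,
            chunks.append(f[i])  # modeling memcpy_s of copy_num elements
    return np.concatenate(chunks).reshape(out_shape)

a = np.arange(8, dtype=np.int32).reshape(2, 2, 2)
b = np.arange(12, dtype=np.int32).reshape(2, 2, 3)
assert (flat_concat([a, b], 2) == np.concatenate((a, b), axis=2)).all()
```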
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
index 90f712295f..383d6d41f3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/concat_cpu_kernel.h
@@ -22,9 +22,10 @@
 namespace mindspore {
 namespace kernel {
+template <typename T>
 class ConcatCPUKernel : public CPUKernel {
  public:
-  ConcatCPUKernel() : axis_(0) {}
+  ConcatCPUKernel() = default;
   ~ConcatCPUKernel() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
 
@@ -34,16 +35,20 @@ class ConcatCPUKernel : public CPUKernel {
  private:
   void CheckParam(const CNodePtr &kernel_node);
-  void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
-                        float **output_addr, size_t *buff_size);
-  int64_t axis_;
-  std::vector<std::vector<size_t>> input_shape_list_;
-  std::vector<size_t> output_shape_;
+  int axis_ = 0;
+  size_t input_num_ = 1;
+  std::vector<std::vector<size_t>> input_flat_shape_list_;
 };
 
-MS_REG_CPU_KERNEL(Concat,
-                  KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                  ConcatCPUKernel);
+MS_REG_CPU_KERNEL_T(
+  Concat, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ConcatCPUKernel, float);
+MS_REG_CPU_KERNEL_T(Concat,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                    ConcatCPUKernel, int)
+MS_REG_CPU_KERNEL_T(Concat,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                    ConcatCPUKernel, bool)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc
index 778bfc2d35..baf1c9e30a 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.cc
@@ -98,5 +98,24 @@ void CPUKernelUtils::ParallelFor(const CTask &task, size_t count) {
   }
 }
 
+std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) {
+  if (axis < 0) {
+    axis = axis + SizeToInt(shape.size());
+  }
+  size_t dim_row = 1;
+  size_t dim_col = 1;
+  std::vector<size_t> flat_shape;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (SizeToInt(i) < axis) {
+      dim_row *= shape[i];
+    } else {
+      dim_col *= shape[i];
+    }
+  }
+  flat_shape.push_back(dim_row);
+  flat_shape.push_back(dim_col);
+  return flat_shape;
+}
+
 }  // namespace kernel
 }  // namespace mindspore
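FlatShapeByAxis is the helper that produces that 2-D view. A line-for-line Python rendering of the new utility, shown only to make the (rows, cols) split concrete; it is a sketch, not the C++ source:

```python
# Python rendering of CPUKernelUtils::FlatShapeByAxis, for illustration.
def flat_shape_by_axis(shape, axis):
    if axis < 0:
        axis += len(shape)  # normalize a negative axis, as the C++ code does
    dim_row, dim_col = 1, 1
    for i, d in enumerate(shape):
        if i < axis:
            dim_row *= d  # dims before the axis collapse into rows
        else:
            dim_col *= d  # the axis dim and everything after become cols
    return [dim_row, dim_col]

assert flat_shape_by_axis([2, 2, 3], 2) == [4, 3]    # (2*2, 3)
assert flat_shape_by_axis([2, 2, 3], -3) == [1, 12]  # axis 0: a single row
```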
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index b097e6f40c..3f577f2f39 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -120,6 +120,7 @@ class CPUKernelUtils {
   static size_t GetElementNumOnAxis(const std::vector<size_t> &shape, int axis);
   static void GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num);
   static void ParallelFor(const CTask &task, size_t count);
+  static std::vector<size_t> FlatShapeByAxis(const std::vector<size_t> &shape, int axis);
 };
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/tests/st/ops/cpu/test_concat_op.py b/tests/st/ops/cpu/test_concat_op.py
index c2a1d07853..0040aef64b 100644
--- a/tests/st/ops/cpu/test_concat_op.py
+++ b/tests/st/ops/cpu/test_concat_op.py
@@ -19,84 +19,290 @@ from mindspore import Tensor
 from mindspore.ops import operations as P
 import mindspore.nn as nn
 import mindspore.context as context
-from mindspore.common import dtype as mstype
 
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 
 
-class Concat_Axis0(nn.Cell):
-    def __init__(self):
-        super(Concat_Axis0, self).__init__()
-        self.cat = P.Concat(axis=0)
+class ConcatV10(nn.Cell):
+    def __init__(self, nptype):
+        super(ConcatV10, self).__init__()
+
+        self.cat = P.Concat(axis=2)
+        self.x1 = Tensor(np.array([[[0., 0., 1.],
+                                    [1., 2., 3.]],
+                                   [[2., 4., 5.],
+                                    [3., 6., 7.]]]).astype(nptype))
+
+    def construct(self):
+        return self.cat((self.x1,))
+
-    def construct(self, x1, x2):
-        return self.cat((x1, x2))
+
+def axis10(nptype):
+    cat = ConcatV10(nptype)
+    output = cat()
+    expect = np.array([[[0., 0., 1.],
+                        [1., 2., 3.]],
+                       [[2., 4., 5.],
+                        [3., 6., 7.]]]).astype(nptype)
+    print(output)
+    assert (output.asnumpy() == expect).all()
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_in2_axis0():
-    x1 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
-    x2 = Tensor(np.arange(3 * 2 * 2).reshape(3, 2, 2), mstype.float32)
-    cat = Concat_Axis0()
-    output_ms = cat(x1, x2)
-    print("output:\n", output_ms)
-    output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=0)
+def test_axis10_float32():
+    axis10(np.float32)
 
-    error = np.ones(shape=output_np.shape) * 10e-6
-    diff = output_ms.asnumpy() - output_np
-    assert np.all(diff < error)
-    assert np.all(-diff < error)
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis10_int32():
+    axis10(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis10_bool():
+    axis10(np.bool)
+
+class ConcatV32(nn.Cell):
+    def __init__(self, nptype):
+        super(ConcatV32, self).__init__()
+
+        self.cat = P.Concat(axis=2)
+        self.x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(nptype))
+        self.x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(nptype))
+
+    def construct(self):
+        return self.cat((self.x1, self.x2))
+
+
+def axis32(nptype):
+    cat = ConcatV32(nptype)
+    output = cat()
+    expect = np.array([[[0., 0., 1.],
+                        [1., 2., 3.]],
+                       [[2., 4., 5.],
+                        [3., 6., 7.]]]).astype(nptype)
+    print(output)
+    assert (output.asnumpy() == expect).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis32_float32():
+    axis32(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis32_bool():
+    axis32(np.bool)
+
+
+class ConcatV43(nn.Cell):
+    def __init__(self, nptype):
+        super(ConcatV43, self).__init__()
+
+        self.cat = P.Concat(axis=3)
+        self.x1 = Tensor(np.arange(2 * 2 * 2 * 2).reshape(2, 2, 2, 2).astype(nptype))
+        self.x2 = Tensor(np.arange(2 * 2 * 2 * 3).reshape(2, 2, 2, 3).astype(nptype))
+
+    def construct(self):
+        return self.cat((self.x1, self.x2))
+
+
+def axis43(nptype):
+    cat = ConcatV43(nptype)
+    output = cat()
+    expect = np.array([[[[0., 1., 0., 1., 2.],
+                         [2., 3., 3., 4., 5.]],
+                        [[4., 5., 6., 7., 8.],
+                         [6., 7., 9., 10., 11.]]],
+                       [[[8., 9., 12., 13., 14.],
+                         [10., 11., 15., 16., 17.]],
+                        [[12., 13., 18., 19., 20.],
+                         [14., 15., 21., 22., 23.]]]]).astype(nptype)
+    assert (output.asnumpy() == expect).all()
+    print(output)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis43_float32():
+    axis43(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis43_int32():
+    axis43(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis43_bool():
+    axis43(np.bool)
+
+
+class ConcatV21(nn.Cell):
+    def __init__(self, nptype):
+        super(ConcatV21, self).__init__()
 
-class Concat_Axis1(nn.Cell):
-    def __init__(self):
-        super(Concat_Axis1, self).__init__()
         self.cat = P.Concat(axis=1)
+        self.x1 = Tensor(np.arange(2 * 2).reshape(2, 2).astype(nptype))
+        self.x2 = Tensor(np.arange(2 * 3).reshape(2, 3).astype(nptype))
+
+    def construct(self):
+        return self.cat((self.x1, self.x2))
+
-    def construct(self, x1, x2):
-        return self.cat((x1, x2))
+
+def axis21(nptype):
+    cat = ConcatV21(nptype)
+    output = cat()
+    expect = np.array([[0., 1., 0., 1., 2.],
+                       [2., 3., 3., 4., 5.]]).astype(nptype)
+    assert (output.asnumpy() == expect).all()
+    print(output)
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_in2_axis1():
-    x1 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
-    x2 = Tensor(np.arange(2 * 3 * 2).reshape(2, 3, 2), mstype.float32)
-    cat = Concat_Axis1()
-    output_ms = cat(x1, x2)
-    print("output:\n", output_ms)
-    output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=1)
+def test_axis21_float32():
+    axis21(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis21_int32():
+    axis21(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis21_bool():
+    axis21(np.bool)
 
-    error = np.ones(shape=output_np.shape) * 10e-6
-    diff = output_ms.asnumpy() - output_np
-    assert np.all(diff < error)
-    assert np.all(-diff < error)
 
-class Concat_in3_Axis2(nn.Cell):
+class Concat3INet(nn.Cell):
     def __init__(self):
-        super(Concat_in3_Axis2, self).__init__()
-        self.cat = P.Concat(axis=-1)
+        super(Concat3INet, self).__init__()
+        self.cat = P.Concat(axis=1)
 
     def construct(self, x1, x2, x3):
         return self.cat((x1, x2, x3))
 
+
+def concat_3i(nptype):
+    cat = Concat3INet()
+
+    x1_np = np.random.randn(32, 4, 224, 224).astype(nptype)
+    x2_np = np.random.randn(32, 8, 224, 224).astype(nptype)
+    x3_np = np.random.randn(32, 10, 224, 224).astype(nptype)
+    output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms)
+
+    error = np.ones(shape=output_np.shape) * 10e-6
+    diff = output_ms.asnumpy() - output_np
+    assert np.all(diff < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_3i_float32():
+    concat_3i(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_3i_int32():
+    concat_3i(np.int32)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_in3_axis2():
-    x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32)
-    x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
-    x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32)
-    cat = Concat_in3_Axis2()
-    output_ms = cat(x1, x2, x3)
-    print("output:\n", output_ms)
-    output_np = np.concatenate((x1.asnumpy(), x2.asnumpy(), x3.asnumpy()), axis=-1)
+def test_concat_3i_bool():
+    cat = Concat3INet()
+
+    x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool)
+    x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool)
+    x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool)
+    output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms)
+
+    assert (output_ms.asnumpy() == output_np).all()
+
+
+class Concat4INet(nn.Cell):
+    def __init__(self):
+        super(Concat4INet, self).__init__()
+        self.cat = P.Concat(axis=1)
+
+    def construct(self, x1, x2, x3, x4):
+        return self.cat((x1, x2, x3, x4))
+
+
+def concat_4i(nptype):
+    cat = Concat4INet()
+
+    x1_np = np.random.randn(32, 4, 224, 224).astype(nptype)
+    x2_np = np.random.randn(32, 8, 224, 224).astype(nptype)
+    x3_np = np.random.randn(32, 10, 224, 224).astype(nptype)
+    x4_np = np.random.randn(32, 5, 224, 224).astype(nptype)
+    output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    x4_ms = Tensor(x4_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms)
 
     error = np.ones(shape=output_np.shape) * 10e-6
     diff = output_ms.asnumpy() - output_np
     assert np.all(diff < error)
-    assert np.all(-diff < error)
 
-if __name__ == '__main__':
-    test_in2_axis0()
-    test_in2_axis1()
-    test_in3_axis2()
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_4i_float32():
+    concat_4i(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_4i_int32():
+    concat_4i(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_4i_bool():
+    cat = Concat4INet()
+
+    x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool)
+    x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool)
+    x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool)
+    x4_np = np.random.choice([True, False], (32, 5, 224, 224)).astype(np.bool)
+    output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    x4_ms = Tensor(x4_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms)
+
+    assert (output_ms.asnumpy() == output_np).all()
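For reference, one way to run just the newly added dtype cases locally (assuming a CPU build of MindSpore is installed and the command is issued from the repository root) is via pytest's Python API:

```python
# Hypothetical local invocation of the updated suite; the -k expression
# narrows the run to the newly supported int32 and bool cases.
import pytest
pytest.main(["-v", "tests/st/ops/cpu/test_concat_op.py", "-k", "int32 or bool"])
```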