diff --git a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc index af86ff8e9b..3bca6a69d3 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.cc @@ -19,15 +19,13 @@ namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE( - Concat, - KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ConcatV2GpuFwdKernel, float) +MS_REG_GPU_KERNEL_ONE(Concat, + KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + ConcatV2GpuFwdKernel, int) MS_REG_GPU_KERNEL_ONE( - Concat, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - ConcatV2GpuFwdKernel, int) -MS_REG_GPU_KERNEL_ONE( - Concat, - KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), ConcatV2GpuFwdKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h index eba6bb87f0..5dabb3045c 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/concatv2_gpu_kernel.h @@ -27,7 +27,7 @@ namespace kernel { template class ConcatV2GpuFwdKernel : public GpuKernel { public: - ConcatV2GpuFwdKernel() : axis_(0), input0_size_(0), input1_size_(0), output_size_(0), workspace_size_(0) {} + ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {} ~ConcatV2GpuFwdKernel() override = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } @@ -35,12 +35,32 @@ class ConcatV2GpuFwdKernel : public GpuKernel { bool Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs, uintptr_t stream_ptr) override { - T *input_0 = GetDeviceAddress(inputs, 0); - T *input_1 = GetDeviceAddress(inputs, 1); - T *output = GetDeviceAddress(outputs, 0); + if (inputs.size() == 2) { + T *input_0 = GetDeviceAddress(inputs, 0); + T *input_1 = GetDeviceAddress(inputs, 1); + T *output = GetDeviceAddress(outputs, 0); + ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, + reinterpret_cast(stream_ptr)); + } + + if (inputs.size() == 3) { + T *input_0 = GetDeviceAddress(inputs, 0); + T *input_1 = GetDeviceAddress(inputs, 1); + T *input_2 = GetDeviceAddress(inputs, 2); + T *output = GetDeviceAddress(outputs, 0); + ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output, + reinterpret_cast(stream_ptr)); + } - CalConcatV2(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output, - reinterpret_cast(stream_ptr)); + if (inputs.size() == 4) { + T *input_0 = GetDeviceAddress(inputs, 0); + T *input_1 = GetDeviceAddress(inputs, 1); + T *input_2 = GetDeviceAddress(inputs, 2); + T *input_3 = GetDeviceAddress(inputs, 3); + T *output = GetDeviceAddress(outputs, 0); + ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output, + reinterpret_cast(stream_ptr)); + } return true; } bool Init(const CNodePtr &kernel_node) override { @@ -48,44 +68,44 @@ class ConcatV2GpuFwdKernel : public GpuKernel { return false; } - auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - input0_size_ = sizeof(T); - for (size_t i = 0; i < input_shape.size(); i++) { - input0_size_ *= input_shape[i]; - } - auto input_shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - input1_size_ = sizeof(T); - for (size_t i = 0; i < input_shape1.size(); i++) { - input1_size_ *= input_shape1[i]; - } - output_size_ = input0_size_ + input1_size_; axis_ = GetAttr(kernel_node, "axis"); if (axis_ < 0) { + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); axis_ += SizeToInt(input_shape.size()); } - w_[0] = 1; - w_[1] = 1; - for (size_t i = IntToSize(axis_); i < input_shape.size(); i++) { - w_[0] *= SizeToInt(input_shape[i]); - w_[1] *= SizeToInt(input_shape1[i]); + + auto input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t i = 0; i < input_num; i++) { + auto input_size = sizeof(T); + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i); + for (size_t j = 0; j < input_shape.size(); j++) { + input_size *= SizeToInt(input_shape[j]); + if (j >= IntToSize(axis_)) { + w_[i] *= SizeToInt(input_shape[j]); + } + input_size_list_.push_back(input_size); + } } + auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + output_size_ = sizeof(T); + for (size_t i = 0; i < output_shape.size(); i++) { + output_size_ *= output_shape[i]; + } + output_size_list_.push_back(output_size_); + InitSizeLists(); return true; } protected: - void InitSizeLists() override { - input_size_list_.push_back(input0_size_); - input_size_list_.push_back(input1_size_); - output_size_list_.push_back(output_size_); - } + void InitSizeLists() override {} private: bool CheckParam(const CNodePtr &kernel_node) { size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 2) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs 2 inputs."; + if (input_num < 2 || input_num > 4) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4."; return false; } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); @@ -95,16 +115,12 @@ class ConcatV2GpuFwdKernel : public GpuKernel { } return true; } - int w_[2] = {1}; + int w_[4] = {1, 1, 1, 1}; int axis_; + size_t output_size_; std::vector input_size_list_; std::vector output_size_list_; std::vector workspace_size_list_; - - size_t input0_size_; - size_t input1_size_; - size_t output_size_; - size_t workspace_size_; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu index fa10494d9c..5cccf183ea 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cu @@ -19,7 +19,7 @@ #include #include "kernel/gpu/cuda_impl/concatv2_impl.cuh" template -__global__ void ConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { +__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) { for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { int n = pos / (w1 + w2); int m = pos % (w1 + w2); @@ -29,16 +29,80 @@ __global__ void ConcatV2(const size_t size, const int w1, const int w2, const T* } template -void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, +__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, + const T* input_1, const T* input_2, const T* input_3, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int n = pos / (w1 + w2 + w3); + int m = pos % (w1 + w2 + w3); + output[pos] = m < w1 ? input_1[n * w1 + m] : + m < w1 + w2 ? input_2[n * w2 + m - w1] : + input_3[n * w3 + m - w1 - w2]; + } + return; +} + +template +__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4, + const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + int n = pos / (w1 + w2 + w3 + w4); + int m = pos % (w1 + w2 + w3 + w4); + output[pos] = m < w1 ? input_1[n * w1 + m] : + m < w1 + w2 ? input_2[n * w2 + m - w1]: + m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]: + input_4[n * w4 + m - w1 - w2 - w3]; + } + return; +} + +template +void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, cudaStream_t cuda_stream) { - ConcatV2<<>>(size, w1, w2, input_1, input_2, output); + Concat<<>>(size, w1, w2, input_1, input_2, output); + return; +} + +template +void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const T* input_1, const T* input_2, const T* input_3, T* output, + cudaStream_t cuda_stream) { + Concat<<>>(size, w1, w2, w3, input_1, input_2, input_3, output); return; } -template void CalConcatV2(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, - float* output, cudaStream_t cuda_stream); -template void CalConcatV2(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, - int* output, cudaStream_t cuda_stream); -template void CalConcatV2(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, - half* output, cudaStream_t cuda_stream); +template +void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, + cudaStream_t cuda_stream) { + Concat<<>>(size, w1, w2, w3, w4, input_1, + input_2, input_3, input_4, output); + return; +} + +template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2, + float* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2, + int* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2, + half* output, cudaStream_t cuda_stream); + +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const float* input_1, const float* input_2, const float* input_3, + float* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const int* input_1, const int* input_2, const int* input_3, + int* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const half* input_1, const half* input_2, const half* input_3, + half* output, cudaStream_t cuda_stream); + +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const float* input_1, const float* input_2, const float* input_3, const float* input_4, + float* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const int* input_1, const int* input_2, const int* input_3, const int* input_4, + int* output, cudaStream_t cuda_stream); +template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const half* input_1, const half* input_2, const half* input_3, const half* input_4, + half* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh index 5cbf61205b..b6932aa4a1 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/concatv2_impl.cuh @@ -19,7 +19,13 @@ #include "device/gpu/cuda_common.h" template -void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, - cudaStream_t cuda_stream); - +void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output, + cudaStream_t cuda_stream); +template +void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, + const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream); +template +void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4, + const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output, + cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ diff --git a/tests/st/ops/gpu/test_concatv2_op.py b/tests/st/ops/gpu/test_concatv2_op.py index c1934762b4..02e2258ffd 100644 --- a/tests/st/ops/gpu/test_concatv2_op.py +++ b/tests/st/ops/gpu/test_concatv2_op.py @@ -113,3 +113,62 @@ def test_axis21(): [2., 3., 3., 4., 5.]] assert (output.asnumpy() == expect).all() print(output) + +class Concat3INet(nn.Cell): + def __init__(self): + super(Concat3INet, self).__init__() + self.cat = P.Concat(axis=1) + + def construct(self, x1, x2, x3): + return self.cat((x1, x2, x3)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_3i(): + cat = Concat3INet() + + x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32) + x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32) + x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32) + output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1) + + x1_ms = Tensor(x1_np) + x2_ms = Tensor(x2_np) + x3_ms = Tensor(x3_np) + output_ms = cat(x1_ms, x2_ms, x3_ms) + + error = np.ones(shape=output_np.shape) * 10e-6 + diff = output_ms.asnumpy() - output_np + assert np.all(diff < error) + + +class Concat4INet(nn.Cell): + def __init__(self): + super(Concat4INet, self).__init__() + self.cat = P.Concat(axis=1) + + def construct(self, x1, x2, x3, x4): + return self.cat((x1, x2, x3, x4)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_concat_4i(): + cat = Concat4INet() + + x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32) + x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32) + x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32) + x4_np = np.random.randn(32, 5, 224, 224).astype(np.float32) + output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1) + + x1_ms = Tensor(x1_np) + x2_ms = Tensor(x2_np) + x3_ms = Tensor(x3_np) + x4_ms = Tensor(x4_np) + output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms) + + error = np.ones(shape=output_np.shape) * 10e-6 + diff = output_ms.asnumpy() - output_np + assert np.all(diff < error)