From ebb7f2e7716b7dc39c9a66f0f796828681cc90cc Mon Sep 17 00:00:00 2001
From: TFBunny
Date: Thu, 11 Feb 2021 11:03:10 -0500
Subject: [PATCH] add float64 support to concat GPU

---
 .../gpu/arrays/concatv2_gpu_kernel.cc |  5 +-
 .../gpu/arrays/concatv2_gpu_kernel.h  |  6 +--
 .../gpu/cuda_impl/concatv2_impl.cu    | 52 ++++++++-----------
 .../gpu/cuda_impl/concatv2_impl.cuh   | 14 +++--
 mindspore/ops/operations/array_ops.py |  2 +-
 tests/st/ops/gpu/test_concatv2_op.py  | 34 ++++++++++--
 6 files changed, 66 insertions(+), 47 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc
index e834b9b592..83377cc5a4 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,9 @@
 
 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  ConcatV2GpuFwdKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   ConcatV2GpuFwdKernel, float)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
index 4a2001d589..88d4c3583f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
 
 #include
 #include
@@ -133,4 +133,4 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
index 032626336e..3ddafa608e 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,9 +19,8 @@
 #include
 #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
 template <typename T>
-__global__ void Concat(const size_t size, const int input_num,
-                       const int all_size_before_axis, const int all_size_axis,
-                       int* len_axis, T** inputs, T* output) {
+__global__ void Concat(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
+                       int *len_axis, T **inputs, T *output) {
   for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
     int num = pos % all_size_before_axis / all_size_axis;
     int block = -1;
@@ -37,45 +36,38 @@ __global__ void Concat(const size_t size, const int input_num,
     }
     block_len = len_axis[block];
     axis_inc -= len_axis[block];
-    int block_pos = pos / all_size_before_axis * block_len * all_size_axis +
-                    (num - axis_inc) * all_size_axis + pos % all_size_axis;;
+    int block_pos =
+      pos / all_size_before_axis * block_len * all_size_axis + (num - axis_inc) * all_size_axis + pos % all_size_axis;
     output[pos] = inputs[block][block_pos];
   }
   return;
 }
 
 template <typename T>
-void ConcatKernel(const size_t size, const int input_num,
-                  const int all_size_before_axis, const int all_size_axis,
-                  int* len_axis, T** inputs, T* output,
-                  cudaStream_t cuda_stream) {
-  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num,
-                                                            all_size_before_axis, all_size_axis,
+void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
+                  int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream) {
+  Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, all_size_before_axis, all_size_axis,
                                                             len_axis, inputs, output);
   return;
 }
 
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, float** inputs, float* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, double **inputs, double *output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, int** inputs, int* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, float **inputs, float *output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, half** inputs, half* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, int **inputs, int *output,
                            cudaStream_t cuda_stream);
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, half **inputs, half *output,
+                           cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, short** inputs, short* output,  // NOLINT
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, short **inputs, short *output,  // NOLINT
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, unsigned char** inputs, unsigned char* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, unsigned char **inputs, unsigned char *output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const size_t size, const int input_num,
-                           const int all_size_before_axis, const int all_size_axis,
-                           int* len_axis, bool** inputs, bool* output,
+template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
+                           const int all_size_axis, int *len_axis, bool **inputs, bool *output,
                            cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
index 6e469e8028..a37de1c454 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,13 +14,11 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
 
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void ConcatKernel(const size_t size, const int input_num,
-                  const int all_size_before_axis, const int all_size_axis,
-                  int* len_axis, T** inputs, T* output,
-                  cudaStream_t cuda_stream);
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
+void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
+                  int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index b14a4d8977..2cb561f6f4 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -2162,7 +2162,7 @@ class Concat(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self, axis=0):
-        """Initialize Tile"""
+        """Initialize Concat"""
         validator.check_value_type("axis", axis, [int], self.name)
 
     def __infer__(self, input_x):
diff --git a/tests/st/ops/gpu/test_concatv2_op.py b/tests/st/ops/gpu/test_concatv2_op.py
index 0b35652eec..5ae7c7a498 100644
--- a/tests/st/ops/gpu/test_concatv2_op.py
+++ b/tests/st/ops/gpu/test_concatv2_op.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -49,9 +49,14 @@ def axis32(nptype):
                         [1., 2., 3.]],
                        [[2., 4., 5.],
                         [3., 6., 7.]]]).astype(nptype)
-    print(output)
     assert (output.asnumpy() == expect).all()
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_axis32_float64():
+    axis32(np.float64)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -106,8 +111,12 @@ def axis43(nptype):
                          [[12., 13., 18., 19., 20.],
                           [14., 15., 21., 22., 23.]]]]).astype(nptype)
     assert (output.asnumpy() == expect).all()
-    print(output)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_axis43_float64():
+    axis43(np.float64)
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -155,7 +164,12 @@ def axis21(nptype):
     expect = np.array([[0., 1., 0., 1., 2.],
                        [2., 3., 3., 4., 5.]]).astype(nptype)
     assert (output.asnumpy() == expect).all()
-    print(output)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_axis21_float64():
+    axis21(np.float64)
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -208,6 +222,12 @@ def concat_3i(nptype):
     diff = output_ms.asnumpy() - output_np
     assert np.all(diff < error)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_concat_3i_float64():
+    concat_3i(np.float64)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -273,6 +293,12 @@ def concat_4i(nptype):
     diff = output_ms.asnumpy() - output_np
     assert np.all(diff < error)
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_concat_4i_float64():
+    concat_4i(np.float64)
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
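
Note (not part of the patch itself): a minimal sketch of what this change enables, namely Concat on float64 inputs with the GPU backend, mirroring the pattern of the new tests. It assumes a GPU-enabled MindSpore build; the shapes and values below are illustrative only.

    # Minimal usage sketch: float64 Concat on the GPU backend (illustrative values).
    import numpy as np
    import mindspore.context as context
    import mindspore.ops.operations as P
    from mindspore import Tensor

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    x1 = Tensor(np.arange(6).reshape(2, 3).astype(np.float64))
    x2 = Tensor(np.arange(6, 12).reshape(2, 3).astype(np.float64))
    concat = P.Concat(axis=1)    # concatenate along the second axis
    output = concat((x1, x2))    # shape (2, 6), dtype float64
    print(output)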