!12350 Add float64 support to the Concat GPU kernel

From: @TFbunny
Reviewed-by: @tom__chen,@robingrosman
Signed-off-by: @robingrosman
pull/12350/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 8fabb26412

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,6 +18,9 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
// Registers the Concat op for float64 data: the input and output attributes are both
// kNumberTypeFloat64, and the registration pairs ConcatV2GpuFwdKernel with element type double.
// NOTE(review): AddAllSameAttr(true) presumably applies the single input attr to every
// concatenated input — confirm against KernelAttr.
MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ConcatV2GpuFwdKernel, double)
MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ConcatV2GpuFwdKernel, float) ConcatV2GpuFwdKernel, float)

@ -14,8 +14,8 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
#include <vector> #include <vector>
#include <memory> #include <memory>
@ -133,4 +133,4 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -19,9 +19,8 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh" #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
template <typename T> template <typename T>
__global__ void Concat(const size_t size, const int input_num, __global__ void Concat(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
const int all_size_before_axis, const int all_size_axis, int *len_axis, T **inputs, T *output) {
int* len_axis, T** inputs, T* output) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int num = pos % all_size_before_axis / all_size_axis; int num = pos % all_size_before_axis / all_size_axis;
int block = -1; int block = -1;
@ -37,45 +36,38 @@ __global__ void Concat(const size_t size, const int input_num,
} }
block_len = len_axis[block]; block_len = len_axis[block];
axis_inc -= len_axis[block]; axis_inc -= len_axis[block];
int block_pos = pos / all_size_before_axis * block_len * all_size_axis + int block_pos =
(num - axis_inc) * all_size_axis + pos % all_size_axis;; pos / all_size_before_axis * block_len * all_size_axis + (num - axis_inc) * all_size_axis + pos % all_size_axis;
output[pos] = inputs[block][block_pos]; output[pos] = inputs[block][block_pos];
} }
return; return;
} }
template <typename T> template <typename T>
void ConcatKernel(const size_t size, const int input_num, void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
const int all_size_before_axis, const int all_size_axis, int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream) {
int* len_axis, T** inputs, T* output, Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, all_size_before_axis, all_size_axis,
cudaStream_t cuda_stream) {
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num,
all_size_before_axis, all_size_axis,
len_axis, inputs, output); len_axis, inputs, output);
return; return;
} }
template void ConcatKernel(const size_t size, const int input_num, template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_before_axis, const int all_size_axis, const int all_size_axis, int *len_axis, double **inputs, double *output,
int* len_axis, float** inputs, float* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_before_axis, const int all_size_axis, const int all_size_axis, int *len_axis, float **inputs, float *output,
int* len_axis, int** inputs, int* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_before_axis, const int all_size_axis, const int all_size_axis, int *len_axis, int **inputs, int *output, cudaStream_t cuda_stream);
int* len_axis, half** inputs, half* output, template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, half **inputs, half *output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_before_axis, const int all_size_axis, const int all_size_axis, int *len_axis, short **inputs, short *output, // NOLINT
int* len_axis, short** inputs, short* output, // NOLINT
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_before_axis, const int all_size_axis, const int all_size_axis, int *len_axis, unsigned char **inputs, unsigned char *output,
int* len_axis, unsigned char** inputs, unsigned char* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_before_axis, const int all_size_axis, const int all_size_axis, int *len_axis, bool **inputs, bool *output,
int* len_axis, bool** inputs, bool* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);

@ -1,5 +1,5 @@
/** /**
* Copyright 2019 Huawei Technologies Co., Ltd * Copyright 2019-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -14,13 +14,11 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
#include "runtime/device/gpu/cuda_common.h" #include "runtime/device/gpu/cuda_common.h"
template <typename T> template <typename T>
void ConcatKernel(const size_t size, const int input_num, void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
const int all_size_before_axis, const int all_size_axis, int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream);
int* len_axis, T** inputs, T* output, #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_

@ -2162,7 +2162,7 @@ class Concat(PrimitiveWithInfer):
@prim_attr_register @prim_attr_register
def __init__(self, axis=0): def __init__(self, axis=0):
"""Initialize Tile""" """Initialize Concat"""
validator.check_value_type("axis", axis, [int], self.name) validator.check_value_type("axis", axis, [int], self.name)
def __infer__(self, input_x): def __infer__(self, input_x):

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd # Copyright 2020-2021 Huawei Technologies Co., Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -49,9 +49,14 @@ def axis32(nptype):
[1., 2., 3.]], [1., 2., 3.]],
[[2., 4., 5.], [[2., 4., 5.],
[3., 6., 7.]]]).astype(nptype) [3., 6., 7.]]]).astype(nptype)
print(output)
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_axis32_float64():
    """Run the axis32 concat check with np.float64 as the numpy dtype."""
    dtype = np.float64
    axis32(dtype)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@ -106,8 +111,12 @@ def axis43(nptype):
[[12., 13., 18., 19., 20.], [[12., 13., 18., 19., 20.],
[14., 15., 21., 22., 23.]]]]).astype(nptype) [14., 15., 21., 22., 23.]]]]).astype(nptype)
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
print(output)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_axis43_float64():
    """Run the axis43 concat check with np.float64 as the numpy dtype."""
    dtype = np.float64
    axis43(dtype)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@ -155,7 +164,12 @@ def axis21(nptype):
expect = np.array([[0., 1., 0., 1., 2.], expect = np.array([[0., 1., 0., 1., 2.],
[2., 3., 3., 4., 5.]]).astype(nptype) [2., 3., 3., 4., 5.]]).astype(nptype)
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
print(output)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_axis21_float64():
    """Run the axis21 concat check with np.float64 as the numpy dtype."""
    dtype = np.float64
    axis21(dtype)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@ -208,6 +222,12 @@ def concat_3i(nptype):
diff = output_ms.asnumpy() - output_np diff = output_ms.asnumpy() - output_np
assert np.all(diff < error) assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_3i_float64():
    """Run the concat_3i check with np.float64 as the numpy dtype."""
    dtype = np.float64
    concat_3i(dtype)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
@ -273,6 +293,12 @@ def concat_4i(nptype):
diff = output_ms.asnumpy() - output_np diff = output_ms.asnumpy() - output_np
assert np.all(diff < error) assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_4i_float64():
    """Run the concat_4i check with np.float64 as the numpy dtype."""
    dtype = np.float64
    concat_4i(dtype)
@pytest.mark.level0 @pytest.mark.level0
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard

Loading…
Cancel
Save