add float64 support to concat GPU

pull/12350/head
TFBunny 4 years ago
parent f9a2b2004f
commit ebb7f2e771

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,6 +18,9 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
ConcatV2GpuFwdKernel, double)
MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ConcatV2GpuFwdKernel, float)

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_
#include <vector>
#include <memory>
@ -133,4 +133,4 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CONCATV2_GPU_KERNEL_H
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_CONCATV2_GPU_KERNEL_H_

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,9 +19,8 @@
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
template <typename T>
__global__ void Concat(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, T** inputs, T* output) {
__global__ void Concat(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
int *len_axis, T **inputs, T *output) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int num = pos % all_size_before_axis / all_size_axis;
int block = -1;
@ -37,45 +36,38 @@ __global__ void Concat(const size_t size, const int input_num,
}
block_len = len_axis[block];
axis_inc -= len_axis[block];
int block_pos = pos / all_size_before_axis * block_len * all_size_axis +
(num - axis_inc) * all_size_axis + pos % all_size_axis;;
int block_pos =
pos / all_size_before_axis * block_len * all_size_axis + (num - axis_inc) * all_size_axis + pos % all_size_axis;
output[pos] = inputs[block][block_pos];
}
return;
}
template <typename T>
void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, T** inputs, T* output,
cudaStream_t cuda_stream) {
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num,
all_size_before_axis, all_size_axis,
void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream) {
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_num, all_size_before_axis, all_size_axis,
len_axis, inputs, output);
return;
}
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, float** inputs, float* output,
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, double **inputs, double *output,
cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, int** inputs, int* output,
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, float **inputs, float *output,
cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, half** inputs, half* output,
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, int **inputs, int *output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, half **inputs, half *output,
cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, short** inputs, short* output, // NOLINT
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, short **inputs, short *output, // NOLINT
cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, unsigned char** inputs, unsigned char* output,
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, unsigned char **inputs, unsigned char *output,
cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, bool** inputs, bool* output,
template void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis,
const int all_size_axis, int *len_axis, bool **inputs, bool *output,
cudaStream_t cuda_stream);

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,13 +14,11 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void ConcatKernel(const size_t size, const int input_num,
const int all_size_before_axis, const int all_size_axis,
int* len_axis, T** inputs, T* output,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
void ConcatKernel(const size_t size, const int input_num, const int all_size_before_axis, const int all_size_axis,
int *len_axis, T **inputs, T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_CONCATV2_IMPL_CUH_

@ -2162,7 +2162,7 @@ class Concat(PrimitiveWithInfer):
@prim_attr_register
def __init__(self, axis=0):
"""Initialize Tile"""
"""Initialize Concat"""
validator.check_value_type("axis", axis, [int], self.name)
def __infer__(self, input_x):

@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -49,9 +49,14 @@ def axis32(nptype):
[1., 2., 3.]],
[[2., 4., 5.],
[3., 6., 7.]]]).astype(nptype)
print(output)
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_axis32_float64():
axis32(np.float64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -106,8 +111,12 @@ def axis43(nptype):
[[12., 13., 18., 19., 20.],
[14., 15., 21., 22., 23.]]]]).astype(nptype)
assert (output.asnumpy() == expect).all()
print(output)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_axis43_float64():
axis43(np.float64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -155,7 +164,12 @@ def axis21(nptype):
expect = np.array([[0., 1., 0., 1., 2.],
[2., 3., 3., 4., 5.]]).astype(nptype)
assert (output.asnumpy() == expect).all()
print(output)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_axis21_float64():
axis21(np.float64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@ -208,6 +222,12 @@ def concat_3i(nptype):
diff = output_ms.asnumpy() - output_np
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_3i_float64():
concat_3i(np.float64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
@ -273,6 +293,12 @@ def concat_4i(nptype):
diff = output_ms.asnumpy() - output_np
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_4i_float64():
concat_4i(np.float64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard

Loading…
Cancel
Save