diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc
index 77e7de6fef..fb462d84d1 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/transpose_gpu_kernel.cc
@@ -21,5 +21,7 @@ MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).A
                       TransposeGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       TransposeGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                      TransposeGpuFwdKernel, int)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
index 8a305fa974..1e1a9e2da6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cu
@@ -15,6 +15,7 @@
  */
 
 #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_grad_impl.cuh"
+#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 
 template <typename T>
@@ -22,9 +23,9 @@ struct MinimumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 < x2) {
-      atomicAdd(dx1, dy);
+      ms_atomic_add(dx1, dy);
     } else if (grad_x2 && x1 >= x2) {
-      atomicAdd(dx2, dy);
+      ms_atomic_add(dx2, dy);
     }
   }
 };
@@ -34,9 +35,9 @@ struct MaximumGradFunc {
   __device__ __forceinline__ void operator()(const T &x1, const T &x2, const bool &grad_x1, const bool &grad_x2,
                                              const T &dy, T *dx1, T *dx2) {
     if (grad_x1 && x1 > x2) {
-      atomicAdd(dx1, dy);
+      ms_atomic_add(dx1, dy);
     } else if (grad_x2 && x1 <= x2) {
-      atomicAdd(dx2, dy);
+      ms_atomic_add(dx2, dy);
     }
   }
 };
@@ -117,6 +118,9 @@ template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &
                               cudaStream_t stream);
 template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
                               const int *x1, const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
+template void NoBroadcastGrad(const int &nums, const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op,
+                              const half *x1, const half *x2, const half *dy, half *dx1, half *dx2,
+                              cudaStream_t stream);
 template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const float *x1,
@@ -125,3 +129,7 @@ template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const i
                             const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
                             const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const int *x1,
                             const int *x2, const int *dy, int *dx1, int *dx2, cudaStream_t stream);
+template void BroadcastGrad(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
+                            const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
+                            const bool &grad_x1, const bool &grad_x2, enum BroadcastGradOpType op, const half *x1,
+                            const half *x2, const half *dy, half *dx1, half *dx2, cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu
index ffcb2c8052..fe38188930 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/transpose_impl.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,8 +15,10 @@
  */
 
 #include <cuda_runtime.h>
+
 #include "transpose_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
+
 template <typename T>
 __global__ void Transpose(const int size, const T* input, const int* input_shape, const int* input_axis,
                           const int shape_size, T* output) {
@@ -63,3 +65,5 @@ template void CalTranspose(const int size, const float* input, const int*
                            const int shape_size, float* output, cudaStream_t cuda_stream);
 template void CalTranspose(const int size, const half* input, const int* input_shape, const int* input_axis,
                            const int shape_size, half* output, cudaStream_t cuda_stream);
+template void CalTranspose(const int size, const int* input, const int* input_shape, const int* input_axis,
+                           const int shape_size, int* output, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc
index 49be2fd9a6..a7e0eeebfd 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_grad_gpu_kernel.cc
@@ -34,6 +34,22 @@ MS_REG_GPU_KERNEL_ONE(MaximumGrad,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       BroadcastOpGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(MinimumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      BroadcastOpGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(MaximumGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      BroadcastOpGradGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(MinimumGrad,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeInt32)
diff --git a/tests/st/ops/gpu/test_transpose_op.py b/tests/st/ops/gpu/test_transpose_op.py
index 87c2d86ee6..44f5a422c7 100644
--- a/tests/st/ops/gpu/test_transpose_op.py
+++ b/tests/st/ops/gpu/test_transpose_op.py
@@ -28,25 +28,25 @@ context.set_context(device_target='GPU')
 
 
 class Transpose(nn.Cell):
-    def __init__(self):
+    def __init__(self, nptype):
         super(Transpose, self).__init__()
         self.transpose = P.Transpose()
 
-        self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(np.float32)), [5, 6]),
+        self.x_2D = Parameter(initializer(Tensor(np.arange(5 * 6).reshape(5, 6).astype(nptype)), [5, 6]),
                               name='x_2D')
         self.perm_2D = (1, 0)
 
-        self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.float32)), [2, 2, 4]),
+        self.x_3D = Parameter(initializer(Tensor(np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(nptype)), [2, 2, 4]),
                               name='x_3D')
         self.perm_3D = (1, 0, 2)
 
         self.x_4D = Parameter(
-            initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(np.float32)), [2, 3, 4, 5]),
+            initializer(Tensor(np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).astype(nptype)), [2, 3, 4, 5]),
             name='x_4D')
         self.perm_4D = (0, 1, 2, 3)
 
         self.x_5D = Parameter(
-            initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(np.float32)),
+            initializer(Tensor(np.arange(1 * 2 * 3 * 4 * 5).reshape(1, 2, 3, 4, 5).astype(nptype)),
                         [1, 2, 3, 4, 5]), name='x_5D')
         self.perm_5D = (1, 0, 3, 4, 2)
 
@@ -56,11 +56,8 @@ class Transpose(nn.Cell):
                 self.transpose(self.x_4D, self.perm_4D),
                 self.transpose(self.x_5D, self.perm_5D))
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_transpose():
-    transpose = Transpose()
+def transpose1(nptype):
+    transpose = Transpose(nptype)
     output = transpose()
 
     expect0 = np.array([[[0, 6, 12, 18, 24],
@@ -68,11 +65,11 @@ def test_transpose():
                         [2, 8, 14, 20, 26],
                         [3, 9, 15, 21, 27],
                         [4, 10, 16, 22, 28],
-                        [5, 11, 17, 23, 29]]]).astype(np.float32)
+                        [5, 11, 17, 23, 29]]]).astype(nptype)
     expect1 = np.array([[[[0, 1, 2, 3],
                           [8, 9, 10, 11]],
                          [[4, 5, 6, 7],
-                          [12, 13, 14, 15]]]]).astype(np.float32)
+                          [12, 13, 14, 15]]]]).astype(nptype)
     expect2 = np.array([[[[[0, 1, 2, 3, 4],
                            [5, 6, 7, 8, 9],
                            [10, 11, 12, 13, 14],
@@ -97,7 +94,7 @@ def test_transpose():
                          [[100, 101, 102, 103, 104],
                           [105, 106, 107, 108, 109],
                           [110, 111, 112, 113, 114],
-                          [115, 116, 117, 118, 119]]]]]).astype(np.float32)
+                          [115, 116, 117, 118, 119]]]]]).astype(nptype)
     expect3 = np.array([[[[[[0, 20, 40],
                             [1, 21, 41],
                             [2, 22, 42],
@@ -138,8 +135,26 @@ def test_transpose():
                              [76, 96, 116],
                              [77, 97, 117],
                              [78, 98, 118],
-                             [79, 99, 119]]]]]]).astype(np.float32)
+                             [79, 99, 119]]]]]]).astype(nptype)
     assert (output[0].asnumpy() == expect0).all()
     assert (output[1].asnumpy() == expect1).all()
     assert (output[2].asnumpy() == expect2).all()
     assert (output[3].asnumpy() == expect3).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_float32():
+    transpose1(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_float16():
+    transpose1(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_transpose_int32():
+    transpose1(np.int32)
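As a quick sanity check of the new registration (separate from the CI tests above), the sketch below runs Transpose on an int32 input so that it dispatches to the newly registered GPU kernel. It assumes a GPU build of MindSpore and mirrors the style of tests/st/ops/gpu/test_transpose_op.py; the Net class and the sample values here are illustrative only, not part of this patch.

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target='GPU')

class Net(nn.Cell):
    """Transposes a 2-D input with a fixed permutation (illustrative helper)."""
    def __init__(self):
        super(Net, self).__init__()
        self.transpose = P.Transpose()

    def construct(self, x):
        # (1, 0) swaps the two axes; the int32 dtype selects the newly registered kernel.
        return self.transpose(x, (1, 0))

x = Tensor(np.arange(2 * 3).reshape(2, 3).astype(np.int32))
output = Net()(x)
# Expected result: [[0, 3], [1, 4], [2, 5]]
print(output.asnumpy())

The float16 MinimumGrad/MaximumGrad registrations follow the same pattern and are reached through the backward pass of P.Minimum/P.Maximum when the inputs are float16.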