diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu
new file mode 100644
index 0000000000..887515b05e
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cu
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh"
+#include "runtime/device/gpu/cuda_common.h"
+
+template <typename T>
+__global__ void CalReLUGradKernel(int size, T *dy, T *y, T *dx) {
+  for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
+    dx[pos] = y[pos] > static_cast<T>(0) ? dy[pos] : static_cast<T>(0);
+  }
+}
+
+template <typename T>
+void CalReLUGrad(int size, T *dy, T *y, T *dx, cudaStream_t cuda_stream) {
+  CalReLUGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy, y, dx);
+  return;
+}
+
+template void CalReLUGrad(int size, float *dy, float *y, float *dx, cudaStream_t cuda_stream);
+template void CalReLUGrad(int size, half *dy, half *y, half *dx, cudaStream_t cuda_stream);
+template void CalReLUGrad(int size, int8_t *dy, int8_t *y, int8_t *dx, cudaStream_t cuda_stream);
+template void CalReLUGrad(int size, int32_t *dy, int32_t *y, int32_t *dx, cudaStream_t cuda_stream);
+template void CalReLUGrad(int size, int64_t *dy, int64_t *y, int64_t *dx, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh
new file mode 100644
index 0000000000..1d1fbbde7c
--- /dev/null
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
+
+#include "runtime/device/gpu/cuda_common.h"
+template <typename T>
+void CalReLUGrad(int input_size, T *dy, T *y, T *dx, cudaStream_t cuda_stream);
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_RELU_GRAD_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
index 57c2bcdee9..d0c0b5f526 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/relu_impl.cu
@@ -33,6 +33,7 @@ void CalReLU(int size, T *input_addr, T *output_addr, cudaStream_t cuda_stream)
 
 template void CalReLU(int size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, half *input_addr, half *output_addr, cudaStream_t cuda_stream);
+template void CalReLU(int size, int8_t *input_addr, int8_t *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int32_t *input_addr, int32_t *output_addr, cudaStream_t cuda_stream);
 template void CalReLU(int size, int64_t *input_addr, int64_t *output_addr, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
index f094bd064d..436da8bdba 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_gpu_kernel.cc
@@ -22,6 +22,8 @@ MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       ActivationGpuFwdKernel, float)
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       ActivationGpuFwdKernel, half)
+MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+                      ActivationGpuFwdKernel, int8_t)
 MS_REG_GPU_KERNEL_ONE(ReLU, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                       ActivationGpuFwdKernel, int32_t)
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
index c3ab7c1cfd..63c10a5525 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.cc
@@ -26,6 +26,12 @@ MS_REG_GPU_KERNEL_ONE(
   ReluGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   ActivationGradGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  ActivationGradGpuKernel, int32_t)
+MS_REG_GPU_KERNEL_ONE(
+  ReluGrad, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
+  ActivationGradGpuKernel, int8_t)
 
 MS_REG_GPU_KERNEL_ONE(
   ReLU6Grad,
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
index bcfb61c58d..c35fe5a70c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/activation_grad_kernel.h
@@ -23,6 +23,7 @@
 #include "backend/kernel_compiler/gpu/gpu_kernel.h"
 #include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
"backend/kernel_compiler/gpu/gpu_kernel_factory.h" #include "backend/kernel_compiler/gpu/kernel_constants.h" +#include "backend/kernel_compiler/gpu/cuda_impl/relu_grad_impl.cuh" namespace mindspore { namespace kernel { @@ -36,7 +37,7 @@ class ActivationGradGpuKernel : public GpuKernel { const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } bool Launch(const std::vector &inputs, const std::vector &, - const std::vector &outputs, void *) override { + const std::vector &outputs, void *stream_ptr) override { if (is_null_input_) { return true; } @@ -51,13 +52,18 @@ class ActivationGradGpuKernel : public GpuKernel { } T *dx = GetDeviceAddress(outputs, 0); - const float alpha = 1; - const float beta = 0; - CHECK_CUDNN_RET_WITH_EXCEPT( - kernel_node_, - cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy, - data_descriptor_, y, &beta, data_descriptor_, dx), - "cudnnActivationBackward failed"); + if (mode_ == CUDNN_ACTIVATION_RELU) { + const int size = input_size_ / sizeof(T); + CalReLUGrad(size, dy, y, dx, reinterpret_cast(stream_ptr)); + } else { + const float alpha = 1; + const float beta = 0; + CHECK_CUDNN_RET_WITH_EXCEPT( + kernel_node_, + cudnnActivationBackward(cudnn_handle_, activation_desc_, &alpha, data_descriptor_, y, data_descriptor_, dy, + data_descriptor_, y, &beta, data_descriptor_, dx), + "cudnnActivationBackward failed"); + } return true; } diff --git a/tests/st/ops/gpu/test_relu_grad_op.py b/tests/st/ops/gpu/test_relu_grad_op.py index c63e492038..43647c0180 100644 --- a/tests/st/ops/gpu/test_relu_grad_op.py +++ b/tests/st/ops/gpu/test_relu_grad_op.py @@ -31,17 +31,14 @@ class NetReluGrad(nn.Cell): return self.rekuGrad(dy, x) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_relu_grad(): +def relu_grad_base(dtype): x = Tensor(np.array([[[[-1, 1, 1], [1, -1, 1], - [1, 1, -1]]]]).astype(np.float32)) + [1, 1, -1]]]]).astype(dtype)) dy = Tensor(np.array([[[[1, 0, 1], [0, 1, 0], - [1, 1, 1]]]]).astype(np.float32)) - expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.float32) + [1, 1, 1]]]]).astype(dtype)) + expect = np.array([[[[0, 0, 1,], [0, 0, 0,], [1, 1, 0.]]]]).astype(np.dtype) error = np.ones(shape=[3, 3]) * 1.0e-6 context.set_context(mode=context.GRAPH_MODE, device_target="GPU") @@ -49,3 +46,39 @@ def test_relu_grad(): output = relu_grad(x, dy) diff = output.asnumpy() - expect assert np.all(diff < error) + assert output.asnumpy().dtype == dtype + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_float16(): + relu_grad_base(np.float16) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_float32(): + relu_grad_base(np.float32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_int8(): + relu_grad_base(np.int8) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_int32(): + relu_grad_base(np.int32) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_int64(): + relu_grad_base(np.int64) diff --git a/tests/st/ops/gpu/test_relu_op.py b/tests/st/ops/gpu/test_relu_op.py index 03443c0e65..b0dc21b0a6 100644 --- a/tests/st/ops/gpu/test_relu_op.py +++ b/tests/st/ops/gpu/test_relu_op.py @@ -65,6 +65,28 @@ def test_relu_float32(): assert 
 
 
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_int8():
+    x = Tensor(np.array([[[[-1, 1, 10],
+                           [1, -1, 1],
+                           [10, 1, -1]]]]).astype(np.int8))
+    expect = np.array([[[[0, 1, 10,],
+                         [1, 0, 1,],
+                         [10, 1, 0.]]]]).astype(np.int8)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    relu = NetRelu()
+    output = relu(x)
+    assert (output.asnumpy() == expect).all()
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard