diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
index d3523ecf6b..3cc485b752 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ __global__ void SqrtGradKernel(const T *input, const T *dout, T *output, const s
   }
   return;
 }
+
 template <typename T>
 __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -37,6 +38,7 @@ __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const
   }
   return;
 }
+
 template <typename T>
 __global__ void AsinGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -46,6 +48,7 @@ __global__ void AsinGradKernel(const T *input, const T *dout, T *output, const s
   }
   return;
 }
+
 template <>
 __global__ void AsinGradKernel(const half *input, const half *dout, half *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -55,6 +58,7 @@ __global__ void AsinGradKernel(const half *input, const half *dout, half *output
   }
   return;
 }
+
 template <typename T>
 __global__ void ACosGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -65,6 +69,7 @@ __global__ void ACosGradKernel(const T *input, const T *dout, T *output, const s
   }
   return;
 }
+
 template <>
 __global__ void ACosGradKernel(const half *input, const half *dout, half *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -75,6 +80,7 @@ __global__ void ACosGradKernel(const half *input, const half *dout, half *output
   }
   return;
 }
+
 template <typename T>
 __global__ void AtanGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -84,6 +90,7 @@ __global__ void AtanGradKernel(const T *input, const T *dout, T *output, const s
   }
   return;
 }
+
 template <typename T>
 __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -93,6 +100,7 @@ __global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const
   }
   return;
 }
+
 template <typename T>
 __global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -102,11 +110,24 @@ __global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const
   }
   return;
 }
+
+template <typename T>
+__global__ void ReciprocalGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
+    float inputf = static_cast<float>(input[i]);
+    float doutf = static_cast<float>(dout[i]);
+    float res = -1 * doutf * inputf * inputf;
+    output[i] = static_cast<T>(res);
+  }
+  return;
+}
+
 template <typename T>
 void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
   SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
   return;
 }
+
 template <typename T>
 void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
   RsqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
@@ -143,20 +164,28 @@ void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cud
   return;
 }
 
+template <typename T>
+void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  ReciprocalGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+  return;
+}
+
 template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                                cudaStream_t cuda_stream);
 template void AsinGrad<float>(const float *input, const float *dout, float *output, const size_t count,
-                       cudaStream_t cuda_stream);
+                              cudaStream_t cuda_stream);
 template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count,
-                       cudaStream_t cuda_stream);
+                              cudaStream_t cuda_stream);
 template void AtanGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void AsinhGrad<float>(const float *input, const float *dout, float *output, const size_t count,
-                        cudaStream_t cuda_stream);
+                               cudaStream_t cuda_stream);
 template void AcoshGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                                cudaStream_t cuda_stream);
+template void ReciprocalGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                                    cudaStream_t cuda_stream);
 template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
@@ -164,10 +193,12 @@ template void RsqrtGrad<half>(const half *input, const half *dout, half *output,
 template void AsinGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count,
-                      cudaStream_t cuda_stream);
+                             cudaStream_t cuda_stream);
 template void AtanGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void AsinhGrad<half>(const half *input, const half *dout, half *output, const size_t count,
-                       cudaStream_t cuda_stream);
+                              cudaStream_t cuda_stream);
 template void AcoshGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                               cudaStream_t cuda_stream);
+template void ReciprocalGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                                   cudaStream_t cuda_stream);
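A note on the formula in `ReciprocalGradKernel`: the kernel is handed the *forward output* y = 1/x (as `input`), not the original x, so the chain rule collapses to a division-free form. A worked derivation under that assumption:

```latex
% Gradient of the reciprocal, expressed via the forward output y = 1/x.
% This is exactly `res = -1 * doutf * inputf * inputf` in the kernel above,
% with `input` holding y and `dout` the incoming gradient.
\frac{\partial}{\partial x}\left(\frac{1}{x}\right) = -\frac{1}{x^{2}} = -y^{2}
\qquad\Longrightarrow\qquad
dx = -\,dout \cdot y^{2}
```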
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
index 8e636d5bd2..63cd9d4673 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@
 template <typename T>
 void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
-
+template <typename T>
+void ReciprocalGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
index 17982ca507..116168f9e5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -74,5 +74,13 @@ MS_REG_GPU_KERNEL_ONE(
   AcoshGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   UnaryGradOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ReciprocalGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  ReciprocalGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
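For readers cross-checking the registrations above: `MS_REG_GPU_KERNEL_ONE` binds the `ReciprocalGrad` op name to `UnaryGradOpGpuKernel` once per dtype, and the kernel does its arithmetic in float even for half tensors. A minimal numpy reference for that behavior (hypothetical helper, not part of the patch):

```python
import numpy as np

def reciprocal_grad_reference(y, dout):
    """Reference for ReciprocalGrad: dx = -dout * y**2.

    Mirrors the CUDA kernel, which upcasts half inputs to
    float for the arithmetic and casts the result back.
    """
    yf = y.astype(np.float32)
    doutf = dout.astype(np.float32)
    res = -1 * doutf * yf * yf
    return res.astype(y.dtype)

# Matches the float16 test data at the end of this patch:
y = np.array([0.01, 0.2, 0.22], dtype=np.float16)
dout = np.array([34, 1, 55], dtype=np.float16)
print(reciprocal_grad_reference(y, dout))  # ~[-0.0034, -0.04, -2.662]
```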
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
index 0a93d7473d..87ee6c88bd 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -35,12 +35,14 @@ enum UnaryGradOptype {
   UNARY_OP_ATAN_GRAD = 4,
   UNARY_OP_ASINH_GRAD = 5,
   UNARY_OP_ACOSH_GRAD = 6,
+  UNARY_OP_RECIPROCAL_GRAD = 7,
   UNARY_OP_GRAD_INVALID_TYPE = 255
 };
 static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {
-  {"SqrtGrad", UNARY_OP_SQRT_GRAD},   {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, {"AsinGrad", UNARY_OP_ASIN_GRAD},
-  {"ACosGrad", UNARY_OP_ACOS_GRAD},   {"AtanGrad", UNARY_OP_ATAN_GRAD},   {"AsinhGrad", UNARY_OP_ASINH_GRAD},
-  {"AcoshGrad", UNARY_OP_ACOSH_GRAD}};
+  {"SqrtGrad", UNARY_OP_SQRT_GRAD},   {"RsqrtGrad", UNARY_OP_RSQRT_GRAD},
+  {"AsinGrad", UNARY_OP_ASIN_GRAD},   {"ACosGrad", UNARY_OP_ACOS_GRAD},
+  {"AtanGrad", UNARY_OP_ATAN_GRAD},   {"AsinhGrad", UNARY_OP_ASINH_GRAD},
+  {"AcoshGrad", UNARY_OP_ACOSH_GRAD}, {"ReciprocalGrad", UNARY_OP_RECIPROCAL_GRAD}};
 
 template <typename T>
 class UnaryGradOpGpuKernel : public GpuKernel {
@@ -101,6 +103,11 @@ class UnaryGradOpGpuKernel : public GpuKernel {
                   reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_RECIPROCAL_GRAD: {
+        ReciprocalGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Unary grad operation " << unary_grad_op_type_ << " is not supported.";
       }
diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py
index f66ec3df21..2f1a1ea325 100755
--- a/mindspore/ops/_grad/grad_math_ops.py
+++ b/mindspore/ops/_grad/grad_math_ops.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -448,22 +448,11 @@ def get_bprop_rsqrt(self):
 
 @bprop_getters.register(P.Reciprocal)
 def get_bprop_reciprocal(self):
     """Grad definition for `Reciprocal` operation."""
-    if self.target == "GPU":
-        neg = P.Neg()
-        mul = P.Mul()
-        square = P.Square()
-        reciprocal = P.Reciprocal()
-
-        def bprop(x, out, dout):
-            g = neg(reciprocal(square(x)))
-            dx = mul(dout, g)
-            return (dx,)
-    else:
-        reciprocal_grad = G.ReciprocalGrad()
+    reciprocal_grad = G.ReciprocalGrad()
 
-        def bprop(x, out, dout):
-            dx = reciprocal_grad(out, dout)
-            return (dx,)
+    def bprop(x, out, dout):
+        dx = reciprocal_grad(out, dout)
+        return (dx,)
 
     return bprop
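With the GPU kernel registered, the GPU-only composite branch (Neg/Mul/Square/Reciprocal) in `get_bprop_reciprocal` becomes dead code, and every backend now routes through `G.ReciprocalGrad(out, dout)`. A minimal sketch of how the unified bprop is exercised end to end (assumes a GPU build; the `Net` wrapper here is illustrative, not part of the patch):

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.reciprocal = ops.Reciprocal()

    def construct(self, x):
        return self.reciprocal(x)

# GradOperation invokes the registered bprop, which now calls
# G.ReciprocalGrad(out, dout) on every backend.
grad_all = ops.GradOperation(get_all=True)
x = Tensor(np.array([1.0, 2.0, 4.0]).astype(np.float32))
print(grad_all(Net())(x))  # d(1/x)/dx = -1/x^2 -> (-1, -0.25, -0.0625)
```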
diff --git a/tests/st/ops/gpu/test_reciprocal_grad_op.py b/tests/st/ops/gpu/test_reciprocal_grad_op.py
new file mode 100644
index 0000000000..eb2f9cf568
--- /dev/null
+++ b/tests/st/ops/gpu/test_reciprocal_grad_op.py
@@ -0,0 +1,91 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops.operations import _grad_ops as G
+
+
+class NetReciprocalGrad(nn.Cell):
+    def __init__(self):
+        super(NetReciprocalGrad, self).__init__()
+        self.grad = G.ReciprocalGrad()
+
+    def construct(self, y, dy):
+        return self.grad(y, dy)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_reciprocal_grad_float32():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    y = Tensor(np.array([[[[-1, 1, 12],
+                           [5, 34, 6],
+                           [10, 2, -1]]]]).astype(np.float32))
+    dy = Tensor(np.array([[[[29, 1, 55],
+                            [2.2, 63, 2],
+                            [3, 3, 12]]]]).astype(np.float32))
+    expect = np.array([[[[-29, -1, -7920],
+                         [-55, -72828, -72],
+                         [-300, -12, -12]]]]).astype(np.float32)
+    net = NetReciprocalGrad()
+    output = net(y, dy)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    y = Tensor(np.array([[[[-1, 1, 12],
+                           [5, 34, 6],
+                           [10, 2, -1]]]]).astype(np.float32))
+    dy = Tensor(np.array([[[[29, 1, 55],
+                            [2.2, 63, 2],
+                            [3, 3, 12]]]]).astype(np.float32))
+    expect = np.array([[[[-29, -1, -7920],
+                         [-55, -72828, -72],
+                         [-300, -12, -12]]]]).astype(np.float32)
+    net = NetReciprocalGrad()
+    output = net(y, dy)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_reciprocal_grad_float16():
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    y = Tensor(np.array([[0.01, 0.2, 0.22],
+                         [10.002, 2, -1]]).astype(np.float16))
+    dy = Tensor(np.array([[34, 1, 55],
+                          [3, 3, 63]]).astype(np.float16))
+    expect = np.array([[-0.0034, -0.03998, -2.662],
+                       [-300, -12, -63]]).astype(np.float16)
+    net = NetReciprocalGrad()
+    output = net(y, dy)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    y = Tensor(np.array([[0.01, 0.2, 0.22],
+                         [10.002, 2, -1]]).astype(np.float16))
+    dy = Tensor(np.array([[34, 1, 55],
+                          [3, 3, 63]]).astype(np.float16))
+    expect = np.array([[-0.0034, -0.03998, -2.662],
+                       [-300, -12, -63]]).astype(np.float16)
+    net = NetReciprocalGrad()
+    output = net(y, dy)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)
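The expected arrays in both tests encode dx = -dy * y^2 elementwise; the float32 fixture, for instance, can be re-derived directly:

```python
import numpy as np

# Re-derive the float32 expectation above: dx = -dy * y^2.
y = np.array([[[[-1, 1, 12], [5, 34, 6], [10, 2, -1]]]], dtype=np.float32)
dy = np.array([[[[29, 1, 55], [2.2, 63, 2], [3, 3, 12]]]], dtype=np.float32)
print(-dy * y * y)  # [[[[-29, -1, -7920], [-55, -72828, -72], [-300, -12, -12]]]]
```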