From 8481fd59d898f616661c055815f98a424cced6ec Mon Sep 17 00:00:00 2001
From: zhouyuanshen
Date: Tue, 27 Oct 2020 10:35:01 +0800
Subject: [PATCH] Add support for ACosGrad and AsinGrad

---
 .../gpu/cuda_impl/unary_op_grad_impl.cu       | 60 ++++++++++++++++++-
 .../gpu/cuda_impl/unary_op_grad_impl.cuh      |  4 ++
 .../gpu/math/unary_op_grad_gpu_kernel.cc      | 16 +++
 .../gpu/math/unary_op_grad_gpu_kernel.h       | 22 ++++++-
 tests/st/ops/gpu/test_acos_grad_op.py         | 46 ++++++++++++++
 tests/st/ops/gpu/test_asin_grad_op.py         | 46 ++++++++++++++
 6 files changed, 191 insertions(+), 3 deletions(-)
 create mode 100644 tests/st/ops/gpu/test_acos_grad_op.py
 create mode 100644 tests/st/ops/gpu/test_asin_grad_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
index 32eb1684cf..e4c8d8a47f 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
@@ -15,6 +15,7 @@
  */
 
 #include "unary_op_grad_impl.cuh"
+
 template <typename T>
 __global__ void SqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -36,7 +37,44 @@ __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const
   }
   return;
 }
-
+template <typename T>
+__global__ void AsinGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    T one = 1;
+    T sqt = sqrtf(one - input[i] * input[i]);
+    output[i] = dout[i] / sqt;
+  }
+  return;
+}
+template <>
+__global__ void AsinGradKernel(const half *input, const half *dout, half *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    half one = 1;
+    half sqt = hsqrt(one - input[i] * input[i]);
+    output[i] = dout[i] / sqt;
+  }
+  return;
+}
+template <typename T>
+__global__ void ACosGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    T neg_one = -1;
+    T one = 1;
+    T sqt = sqrtf(one - input[i] * input[i]);
+    output[i] = neg_one * dout[i] / sqt;
+  }
+  return;
+}
+template <>
+__global__ void ACosGradKernel(const half *input, const half *dout, half *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    half neg_one = -1;
+    half one = 1;
+    half sqt = hsqrt(one - input[i] * input[i]);
+    output[i] = neg_one * dout[i] / sqt;
+  }
+  return;
+}
 template <typename T>
 void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
   SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
@@ -48,11 +86,31 @@ void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cud
   return;
 }
 
+template <typename T>
+void AsinGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  AsinGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+  return;
+}
+
+template <typename T>
+void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  ACosGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+  return;
+}
+
 template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                                cudaStream_t cuda_stream);
+template void AsinGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                              cudaStream_t cuda_stream);
+template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                              cudaStream_t cuda_stream);
 template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                               cudaStream_t cuda_stream);
+template void AsinGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                             cudaStream_t cuda_stream);
+template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                             cudaStream_t cuda_stream);
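Note (not part of the patch): the four kernels above implement the analytic derivatives d/dx asin(x) = 1/sqrt(1 - x^2) and d/dx acos(x) = -1/sqrt(1 - x^2), scaled elementwise by the incoming gradient dout. A minimal NumPy cross-check of the same math; the helper names are illustrative, not from the patch:

    import numpy as np

    def asin_grad_ref(x, dout):
        # mirrors AsinGradKernel: dout / sqrt(1 - x * x)
        return dout / np.sqrt(1.0 - x * x)

    def acos_grad_ref(x, dout):
        # mirrors ACosGradKernel: -dout / sqrt(1 - x * x)
        return -dout / np.sqrt(1.0 - x * x)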
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
index 61256ac73a..c5aaaf278e 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
@@ -22,5 +22,9 @@ template <typename T>
 void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void AsinGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
index 43c5334c2c..308f8994ac 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
@@ -34,5 +34,21 @@ MS_REG_GPU_KERNEL_ONE(
   RsqrtGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   UnaryGradOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  AsinGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  AsinGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  ACosGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  ACosGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
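Note (not part of the patch): the registrations above expose AsinGrad and ACosGrad for float32 and float16 only; the header change that follows maps the op name to an enum and dispatches it in the kernel's Launch switch. One behavioral edge case: at |x| = 1 the denominator sqrt(1 - x^2) is zero, so the gradient overflows to inf rather than raising, which is why the tests keep their inputs strictly inside (-1, 1). A small NumPy illustration:

    import numpy as np

    with np.errstate(divide="ignore"):
        x = np.array([1.0, -1.0], dtype=np.float32)
        print(1.0 / np.sqrt(1.0 - x * x))  # -> [inf inf]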
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
index e78676fd01..23ee69e1f6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
@@ -27,9 +27,17 @@
 
 namespace mindspore {
 namespace kernel {
-enum UnaryGradOptype { UNARY_OP_SQRT_GRAD = 0, UNARY_OP_RSQRT_GRAD, UNARY_OP_GRAD_INVALID_TYPE = 255 };
+enum UnaryGradOptype {
+  UNARY_OP_SQRT_GRAD = 0,
+  UNARY_OP_RSQRT_GRAD = 1,
+  UNARY_OP_ASIN_GRAD = 2,
+  UNARY_OP_ACOS_GRAD = 3,
+  UNARY_OP_GRAD_INVALID_TYPE = 255
+};
 static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {{"SqrtGrad", UNARY_OP_SQRT_GRAD},
-                                                                           {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}};
+                                                                           {"RsqrtGrad", UNARY_OP_RSQRT_GRAD},
+                                                                           {"AsinGrad", UNARY_OP_ASIN_GRAD},
+                                                                           {"ACosGrad", UNARY_OP_ACOS_GRAD}};
 template <typename T>
 class UnaryGradOpGpuKernel : public GpuKernel {
  public:
@@ -59,6 +67,16 @@ class UnaryGradOpGpuKernel : public GpuKernel {
                  reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_ASIN_GRAD: {
+        AsinGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
+      case UNARY_OP_ACOS_GRAD: {
+        ACosGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       case UNARY_OP_RSQRT_GRAD: {
         RsqrtGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
                   reinterpret_cast<cudaStream_t>(stream_ptr));
diff --git a/tests/st/ops/gpu/test_acos_grad_op.py b/tests/st/ops/gpu/test_acos_grad_op.py
new file mode 100644
index 0000000000..c1f5695306
--- /dev/null
+++ b/tests/st/ops/gpu/test_acos_grad_op.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+import mindspore.ops.operations._grad_ops as P
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_acosgrad_fp32():
+    error = np.ones(4) * 1.0e-7
+    x_np = np.array([0, -0.25, 0.5, 0.3]).astype(np.float32)
+    dout_np = np.array([1, 1, 1, 1]).astype(np.float32)
+    output_ms = P.ACosGrad()(Tensor(x_np), Tensor(dout_np))
+    expect = np.array([-1, -1.0327955, -1.1547005, -1.0482849])
+    diff = output_ms.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_acosgrad_fp16():
+    error = np.ones(4) * 1.0e-3
+    x_np = np.array([0, -0.25, 0.5, 0.3]).astype(np.float16)
+    dout_np = np.array([1, 1, 1, 1]).astype(np.float16)
+    output_ms = P.ACosGrad()(Tensor(x_np), Tensor(dout_np))
+    expect = np.array([-1, -1.033, -1.154, -1.048])
+    diff = output_ms.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
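Note (not part of the patch): the expected constants in this test follow directly from the analytic gradient, assuming ACosGrad computes -dout / sqrt(1 - x^2) as implemented above:

    import numpy as np

    x = np.array([0, -0.25, 0.5, 0.3], dtype=np.float32)
    print(-1.0 / np.sqrt(1.0 - x * x))
    # -> [-1.        -1.0327955 -1.1547005 -1.0482849]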
diff --git a/tests/st/ops/gpu/test_asin_grad_op.py b/tests/st/ops/gpu/test_asin_grad_op.py
new file mode 100644
index 0000000000..5d3606fb66
--- /dev/null
+++ b/tests/st/ops/gpu/test_asin_grad_op.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+import mindspore.ops.operations._grad_ops as P
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_asingrad_fp32():
+    error = np.ones(4) * 1.0e-7
+    x_np = np.array([0, -0.25, 0.5, 0.3]).astype(np.float32)
+    dout_np = np.array([1, 1, 1, 1]).astype(np.float32)
+    output_ms = P.AsinGrad()(Tensor(x_np), Tensor(dout_np))
+    expect = np.array([1, 1.0327955, 1.1547005, 1.0482849])
+    diff = output_ms.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_asingrad_fp16():
+    error = np.ones(4) * 1.0e-3
+    x_np = np.array([0, -0.25, 0.5, 0.3]).astype(np.float16)
+    dout_np = np.array([1, 1, 1, 1]).astype(np.float16)
+    output_ms = P.AsinGrad()(Tensor(x_np), Tensor(dout_np))
+    expect = np.array([1, 1.033, 1.154, 1.048])
+    diff = output_ms.asnumpy() - expect
+    assert np.all(np.abs(diff) < error)
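Note (not part of the patch): the float16 tests use the looser 1.0e-3 bound because half precision carries only 10 mantissa bits, so one ulp near 1.0 is about 1e-3. A rough NumPy sketch of the rounding error involved; NumPy's float16 arithmetic only approximates the device's half-precision hsqrt path:

    import numpy as np

    x32 = np.array([0, -0.25, 0.5, 0.3], dtype=np.float32)
    x16 = x32.astype(np.float16)
    exact = 1.0 / np.sqrt(1.0 - x32 * x32)
    approx = (1.0 / np.sqrt(1.0 - x16 * x16)).astype(np.float32)
    print(np.max(np.abs(approx - exact)))  # a few float16 ulps, below the 1e-3 bound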