From 1804c042ec7d5bbae23fb05a2a5869a1ec2c4d23 Mon Sep 17 00:00:00 2001 From: zhouyuanshen Date: Wed, 28 Oct 2020 09:37:24 +0800 Subject: [PATCH] fix bug that asin/acos not support fp16 on gpu --- .../kernel_compiler/gpu/cuda_impl/unary_op_impl.cu | 10 +++++++--- .../kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh | 2 +- .../kernel_compiler/gpu/math/unary_op_gpu_kernel.cc | 6 +++++- .../kernel_compiler/gpu/math/unary_op_gpu_kernel.h | 2 +- tests/st/ops/gpu/test_acos_op.py | 9 +++++++++ tests/st/ops/gpu/test_asin_op.py | 9 +++++++++ 6 files changed, 32 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu index aee6939eb5..de51a843a0 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -111,7 +111,9 @@ __global__ void SinKernel(const half *input, half *output, const size_t count) { template <typename T> __global__ void AsinKernel(const T *input, T *output, const size_t count) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { - output[i] = asinf(input[i]); + float inputf = static_cast<float>(input[i]); + T res = static_cast<T>(asinf(inputf)); + output[i] = res; } return; } @@ -132,7 +134,9 @@ __global__ void CosKernel(const half *input, half *output, const size_t count) { template <typename T> __global__ void ACosKernel(const T *input, T *output, const size_t count) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { - output[i] = acosf(input[i]); + float inputf = static_cast<float>(input[i]); + T res = static_cast<T>(acosf(inputf)); + output[i] = res; } return; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh index 3556e45cd5..30f845c9bd 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc index 3c5078f114..560cd55890 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,12 +56,16 @@ MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutp UnaryOpGpuKernel, half) MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h index d9332d24f3..620b2e17bd 100644 --- 
a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tests/st/ops/gpu/test_acos_op.py b/tests/st/ops/gpu/test_acos_op.py index 724db6b4a1..3d27de6fa0 100644 --- a/tests/st/ops/gpu/test_acos_op.py +++ b/tests/st/ops/gpu/test_acos_op.py @@ -29,3 +29,12 @@ def test_acos_fp32(): output_ms = P.ACos()(Tensor(x_np)) output_np = np.arccos(x_np) assert np.allclose(output_ms.asnumpy(), output_np) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_acos_fp16(): + x_np = np.array([0.74, 0.04, 0.30, 0.56]).astype(np.float16) + output_ms = P.ACos()(Tensor(x_np)) + output_np = np.arccos(x_np) + assert np.allclose(output_ms.asnumpy(), output_np) diff --git a/tests/st/ops/gpu/test_asin_op.py b/tests/st/ops/gpu/test_asin_op.py index c781996dfa..b034a225b5 100644 --- a/tests/st/ops/gpu/test_asin_op.py +++ b/tests/st/ops/gpu/test_asin_op.py @@ -29,3 +29,12 @@ def test_asin_fp32(): output_ms = P.Asin()(Tensor(x_np)) output_np = np.arcsin(x_np) assert np.allclose(output_ms.asnumpy(), output_np) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_asin_fp16(): + x_np = np.array([0.74, 0.04, 0.30, 0.56]).astype(np.float16) + output_ms = P.Asin()(Tensor(x_np)) + output_np = np.arcsin(x_np) + assert np.allclose(output_ms.asnumpy(), output_np)