From 8132e56417fddcd0a4843bf82f3728b04e318b4c Mon Sep 17 00:00:00 2001 From: peixu_ren Date: Wed, 16 Sep 2020 14:31:22 -0400 Subject: [PATCH] Add dtype float16 that erf and erfc should support --- .../kernel_compiler/gpu/cuda_impl/erf_impl.cu | 3 ++- .../kernel_compiler/gpu/cuda_impl/erfc_impl.cu | 3 ++- .../kernel_compiler/gpu/math/erf_gpu_kernel.cc | 2 ++ .../kernel_compiler/gpu/math/erfc_gpu_kernel.cc | 2 ++ tests/st/ops/gpu/test_erf_op.py | 15 +++++++++++++-- tests/st/ops/gpu/test_erfc_op.py | 15 +++++++++++++-- 6 files changed, 34 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cu index 257c8503b7..931df0300e 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erf_impl.cu @@ -18,7 +18,7 @@ template __global__ void ErfKernel(T *input, T *output, size_t count) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { - output[i] = (T)erf(input[i]); + output[i] = static_cast(erf(static_cast(input[i]))); } return; } @@ -30,3 +30,4 @@ void Erf(T *input, T *output, size_t count, cudaStream_t cuda_stream) { } template void Erf(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Erf(half *input, half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cu index 6b20cd5537..1e341eba43 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/erfc_impl.cu @@ -18,7 +18,7 @@ template __global__ void ErfcKernel(T *input, T *output, size_t count) { for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { - output[i] = (T)erfc(input[i]); + output[i] = static_cast(erfc(static_cast(input[i]))); } return; } @@ -30,3 +30,4 @@ void Erfc(T *input, T *output, size_t count, cudaStream_t cuda_stream) { } template void Erfc(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Erfc(half *input, half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.cc index 3531e9ccca..adf5286fef 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erf_gpu_kernel.cc @@ -20,5 +20,7 @@ namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ErfGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Erf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ErfGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.cc index cb63ed6f7f..6725bffbd2 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/erfc_gpu_kernel.cc @@ -20,5 +20,7 @@ namespace mindspore { namespace kernel { MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), ErfcGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + ErfcGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/tests/st/ops/gpu/test_erf_op.py b/tests/st/ops/gpu/test_erf_op.py index 98c2085137..93188bbc6c 100644 --- a/tests/st/ops/gpu/test_erf_op.py +++ b/tests/st/ops/gpu/test_erf_op.py @@ -37,10 +37,21 @@ class NetErf(nn.Cell): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_exp(): +def test_erf_fp32(): erf = NetErf() - x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) + x = np.random.rand(3, 8).astype(np.float32) output = erf(Tensor(x, dtype=dtype.float32)) expect = special.erf(x) tol = 1e-6 assert (np.abs(output.asnumpy() - expect) < tol).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_erf_fp16(): + erf = NetErf() + x = np.random.rand(3, 8).astype(np.float16) + output = erf(Tensor(x, dtype=dtype.float16)) + expect = special.erf(x) + tol = 1e-3 + assert (np.abs(output.asnumpy() - expect) < tol).all() diff --git a/tests/st/ops/gpu/test_erfc_op.py b/tests/st/ops/gpu/test_erfc_op.py index 38be92bf9c..01bdbbc0a4 100644 --- a/tests/st/ops/gpu/test_erfc_op.py +++ b/tests/st/ops/gpu/test_erfc_op.py @@ -37,10 +37,21 @@ class NetErfc(nn.Cell): @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard -def test_exp(): +def test_erfc_fp32(): erfc = NetErfc() - x = np.array([2.0, 3.0, 4.0, 5.0]).astype(np.float32) + x = np.random.rand(3, 8).astype(np.float32) output = erfc(Tensor(x, dtype=dtype.float32)) expect = special.erfc(x) tol = 1e-6 assert (np.abs(output.asnumpy() - expect) < tol).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_erfc_fp16(): + erfc = NetErfc() + x = np.random.rand(3, 8).astype(np.float16) + output = erfc(Tensor(x, dtype=dtype.float16)) + expect = special.erfc(x) + tol = 1e-3 + assert (np.abs(output.asnumpy() - expect) < tol).all()