!7879 fix bug that asin/acos not support fp16 on gpu

Merge pull request !7879 from zhouyuanshen/asin_acos_fp16
pull/7879/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit f3e8798b40

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -111,7 +111,9 @@ __global__ void SinKernel(const half *input, half *output, const size_t count) {
template <typename T>
__global__ void AsinKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = asinf(input[i]);
float inputf = static_cast<float>(input[i]);
T res = static_cast<T>(asinf(inputf));
output[i] = res;
}
return;
}
@ -132,7 +134,9 @@ __global__ void CosKernel(const half *input, half *output, const size_t count) {
template <typename T>
__global__ void ACosKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = acosf(input[i]);
float inputf = static_cast<float>(input[i]);
T res = static_cast<T>(acosf(inputf));
output[i] = res;
}
return;
}

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -56,12 +56,16 @@ MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutp
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@ -29,3 +29,12 @@ def test_acos_fp32():
output_ms = P.ACos()(Tensor(x_np))
output_np = np.arccos(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_acos_fp16():
x_np = np.array([0.74, 0.04, 0.30, 0.56]).astype(np.float16)
output_ms = P.ACos()(Tensor(x_np))
output_np = np.arccos(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)

@ -29,3 +29,12 @@ def test_asin_fp32():
output_ms = P.Asin()(Tensor(x_np))
output_np = np.arcsin(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asin_fp16():
x_np = np.array([0.74, 0.04, 0.30, 0.56]).astype(np.float16)
output_ms = P.Asin()(Tensor(x_np))
output_np = np.arcsin(x_np)
assert np.allclose(output_ms.asnumpy(), output_np)

Loading…
Cancel
Save