diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
index e4c8d8a47f..d3523ecf6b 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cu
@@ -76,6 +76,33 @@ __global__ void ACosGradKernel(const half *input, const half *dout, half *output
   return;
 }
 template <typename T>
+__global__ void AtanGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    T one = 1;
+    T divisor = one + input[i] * input[i];
+    output[i] = dout[i] / divisor;
+  }
+  return;
+}
+template <typename T>
+__global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    float inputf = static_cast<float>(input[i]);
+    T coshy = static_cast<T>(coshf(inputf));
+    output[i] = dout[i] / coshy;
+  }
+  return;
+}
+template <typename T>
+__global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    float inputf = static_cast<float>(input[i]);
+    T sinhy = static_cast<T>(sinhf(inputf));
+    output[i] = dout[i] / sinhy;
+  }
+  return;
+}
+template <typename T>
 void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
   SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
   return;
@@ -98,6 +125,24 @@ void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cuda
   return;
 }
 
+template <typename T>
+void AtanGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  AtanGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+  return;
+}
+
+template <typename T>
+void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  AsinhGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+  return;
+}
+
+template <typename T>
+void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
+  AcoshGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
+  return;
+}
+
 template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
 template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
@@ -106,6 +151,12 @@ template void AsinGrad<float>(const float *input, const float *dout, float *outp
                              cudaStream_t cuda_stream);
 template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count,
                               cudaStream_t cuda_stream);
+template void AtanGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                              cudaStream_t cuda_stream);
+template void AsinhGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                               cudaStream_t cuda_stream);
+template void AcoshGrad<float>(const float *input, const float *dout, float *output, const size_t count,
+                               cudaStream_t cuda_stream);
 template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
 template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
@@ -114,3 +165,9 @@ template void AsinGrad<half>(const half *input, const half *dout, half *output,
                             cudaStream_t cuda_stream);
 template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count,
                              cudaStream_t cuda_stream);
+template void AtanGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                             cudaStream_t cuda_stream);
+template void AsinhGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                              cudaStream_t cuda_stream);
+template void AcoshGrad<half>(const half *input, const half *dout, half *output, const size_t count,
+                              cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
index c5aaaf278e..8e636d5bd2 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_grad_impl.cuh
@@ -26,5 +26,12 @@ template <typename T>
 void AsinGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void AtanGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
+
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_
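For reference, the backward kernels above implement the analytic derivatives. AtanGrad receives the forward input x and computes dout / (1 + x^2), while AsinhGrad and AcoshGrad receive the forward output y and compute dout / cosh(y) and dout / sinh(y): with x = sinh(y) we get dy/dx = 1/sqrt(1 + x^2) = 1/cosh(y), and with x = cosh(y) we get dy/dx = 1/sqrt(x^2 - 1) = 1/sinh(y). A minimal NumPy sketch of the same formulas, checked against finite differences (illustrative only, not part of the patch):

```python
import numpy as np

def atan_grad(x, dout):
    return dout / (1.0 + x * x)      # d/dx atan(x) = 1 / (1 + x^2)

def asinh_grad(y, dout):
    return dout / np.cosh(y)         # y = asinh(x), so dx = dout / cosh(y)

def acosh_grad(y, dout):
    return dout / np.sinh(y)         # y = acosh(x), so dx = dout / sinh(y)

eps = 1e-6

x = np.linspace(-2.0, 2.0, 5)
num = (np.arctan(x + eps) - np.arctan(x - eps)) / (2 * eps)
assert np.allclose(atan_grad(x, np.ones_like(x)), num, atol=1e-4)

y = np.arcsinh(x)
num = (np.arcsinh(x + eps) - np.arcsinh(x - eps)) / (2 * eps)
assert np.allclose(asinh_grad(y, np.ones_like(x)), num, atol=1e-4)

x = np.linspace(1.5, 3.0, 5)         # acosh needs x > 1
y = np.arccosh(x)
num = (np.arccosh(x + eps) - np.arccosh(x - eps)) / (2 * eps)
assert np.allclose(acosh_grad(y, np.ones_like(x)), num, atol=1e-4)
```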
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
index 890124d6ae..d7286e08d5 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
@@ -146,6 +146,15 @@ __global__ void AsinKernel(const T *input, T *output, const size_t count) {
   return;
 }
 template <typename T>
+__global__ void AsinhKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    float inputf = static_cast<float>(input[i]);
+    T res = static_cast<T>(asinhf(inputf));
+    output[i] = res;
+  }
+  return;
+}
+template <typename T>
 __global__ void CosKernel(const T *input, T *output, const size_t count) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
     output[i] = cos(input[i]);
@@ -169,6 +178,24 @@ __global__ void ACosKernel(const T *input, T *output, const size_t count) {
   return;
 }
 template <typename T>
+__global__ void AcoshKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    float inputf = static_cast<float>(input[i]);
+    T res = static_cast<T>(acoshf(inputf));
+    output[i] = res;
+  }
+  return;
+}
+template <typename T>
+__global__ void AtanKernel(const T *input, T *output, const size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    float inputf = static_cast<float>(input[i]);
+    T res = static_cast<T>(atanf(inputf));
+    output[i] = res;
+  }
+  return;
+}
+template <typename T>
 __global__ void ZeroslikeKernel(T *output, const size_t count) {
   T zero = 0.0;
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -281,6 +308,21 @@ void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
   return;
 }
 template <typename T>
+void Atan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  AtanKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+template <typename T>
+void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  AsinhKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+template <typename T>
+void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
+  AcoshKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+template <typename T>
 void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
   RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
@@ -315,6 +357,9 @@ template void Sin<float>(const float *input, float *output, const size_t count,
 template void Cos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Asin<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void ACos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Atan<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Asinh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+template void Acosh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream);
 template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
@@ -333,6 +378,9 @@ template void Sin<half>(const half *input, half *output, const size_t count, cud
 template void Cos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Asin<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void ACos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Atan<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Asinh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
+template void Acosh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Zeroslike<half>(half *output, const size_t count, cudaStream_t cuda_stream);
 template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
index 1bb94e07e2..e0347a1d93 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
@@ -49,6 +49,12 @@ void Asin(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
 template <typename T>
 void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
+void Atan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
+template <typename T>
 void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
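Note that both the float and half instantiations of the new forward kernels route the computation through single-precision device intrinsics (asinhf, acoshf, atanf): the value is promoted to float, evaluated, and cast back to T. The float16 tests later in this patch build their expected values the same way; a rough NumPy emulation of that half-precision path (illustrative only, not part of the patch):

```python
import numpy as np

def asinh_fp16_like(x_fp16):
    # mirror the kernel: promote to float32, evaluate, cast back to float16
    return np.arcsinh(x_fp16.astype(np.float32)).astype(np.float16)

x = np.random.rand(4, 2).astype(np.float16) * 10
print(asinh_fp16_like(x))
```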
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
index 7def6ab7f0..91571c818b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
@@ -74,6 +74,10 @@ MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -82,6 +86,14 @@ MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
 MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
index bf94c0bfcd..29c4d12b89 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
@@ -44,6 +44,9 @@ enum UnaryOptype {
   UNARY_OP_COS,
   UNARY_OP_ASIN,
   UNARY_OP_ACOS,
+  UNARY_OP_ATAN,
+  UNARY_OP_ASINH,
+  UNARY_OP_ACOSH,
   UNARY_OP_ABS,
   UNARY_OP_FLOOR,
   UNARY_OP_INVALID_TYPE = 255
@@ -64,6 +67,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
                                                                    {"Cos", UNARY_OP_COS},
                                                                    {"Asin", UNARY_OP_ASIN},
                                                                    {"ACos", UNARY_OP_ACOS},
+                                                                   {"Atan", UNARY_OP_ATAN},
+                                                                   {"Asinh", UNARY_OP_ASINH},
+                                                                   {"Acosh", UNARY_OP_ACOSH},
                                                                    {"Abs", UNARY_OP_ABS},
                                                                    {"Floor", UNARY_OP_FLOOR}};
 template <typename T>
@@ -142,6 +148,18 @@ class UnaryOpGpuKernel : public GpuKernel {
         ACos(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_ATAN: {
+        Atan(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
+      case UNARY_OP_ASINH: {
+        Asinh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
+      case UNARY_OP_ACOSH: {
+        Acosh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       case UNARY_OP_ZEROSLIKE: {
         Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         return true;
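With the registrations and dispatch cases above in place, the three forward operators become usable on GPU for float32 and float16 inputs. A minimal usage sketch, mirroring the system tests added later in this patch (illustrative only):

```python
import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

x = Tensor(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))
print(P.Atan()(x))   # elementwise arctan
print(P.Asinh()(x))  # elementwise inverse hyperbolic sine
print(P.Acosh()(x))  # elementwise inverse hyperbolic cosine (inputs must be >= 1)
```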
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
index 308f8994ac..17982ca507 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.cc
@@ -50,5 +50,29 @@ MS_REG_GPU_KERNEL_ONE(
   ACosGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   UnaryGradOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  AtanGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  AtanGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  AsinhGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  AsinhGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  AcoshGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  UnaryGradOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  AcoshGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  UnaryGradOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
index 23ee69e1f6..0a93d7473d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_grad_gpu_kernel.h
@@ -32,12 +32,16 @@ enum UnaryGradOptype {
   UNARY_OP_RSQRT_GRAD = 1,
   UNARY_OP_ASIN_GRAD = 2,
   UNARY_OP_ACOS_GRAD = 3,
+  UNARY_OP_ATAN_GRAD = 4,
+  UNARY_OP_ASINH_GRAD = 5,
+  UNARY_OP_ACOSH_GRAD = 6,
   UNARY_OP_GRAD_INVALID_TYPE = 255
 };
-static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {{"SqrtGrad", UNARY_OP_SQRT_GRAD},
-                                                                           {"RsqrtGrad", UNARY_OP_RSQRT_GRAD},
-                                                                           {"AsinGrad", UNARY_OP_ASIN_GRAD},
-                                                                           {"ACosGrad", UNARY_OP_ACOS_GRAD}};
+static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {
+  {"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, {"AsinGrad", UNARY_OP_ASIN_GRAD},
+  {"ACosGrad", UNARY_OP_ACOS_GRAD}, {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD},
+  {"AcoshGrad", UNARY_OP_ACOSH_GRAD}};
+
 template <typename T>
 class UnaryGradOpGpuKernel : public GpuKernel {
  public:
@@ -77,6 +81,21 @@ class UnaryGradOpGpuKernel : public GpuKernel {
                  reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_ATAN_GRAD: {
+        AtanGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
+      case UNARY_OP_ASINH_GRAD: {
+        AsinhGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
+                  reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
+      case UNARY_OP_ACOSH_GRAD: {
+        AcoshGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
+                  reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       case UNARY_OP_RSQRT_GRAD: {
         RsqrtGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
                   reinterpret_cast<cudaStream_t>(stream_ptr));
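The system tests below call the grad operators directly through mindspore.ops.operations._grad_ops. As a complementary check, these same kernels are what autodiff should reach when differentiating the forward ops end to end; a rough sketch with GradOperation follows (the exact GradOperation constructor signature varies between MindSpore versions, so treat this as illustrative, not part of the patch):

```python
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
from mindspore.ops import operations as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

class AtanNet(nn.Cell):
    def __init__(self):
        super(AtanNet, self).__init__()
        self.atan = P.Atan()

    def construct(self, x):
        return self.atan(x)

x_np = np.array([[0.5, 1.0], [2.0, -0.5]], dtype=np.float32)
sens_np = np.ones((2, 2), dtype=np.float32)
grad = C.GradOperation(sens_param=True)
dx = grad(AtanNet())(Tensor(x_np), Tensor(sens_np))
# AtanGradKernel should yield sens / (1 + x^2)
assert np.allclose(dx.asnumpy(), sens_np / (1 + x_np ** 2), 1e-4, 1e-4)
```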
diff --git a/tests/st/ops/gpu/test_acosh_grad_op.py b/tests/st/ops/gpu/test_acosh_grad_op.py
new file mode 100644
index 0000000000..2d4e015d40
--- /dev/null
+++ b/tests/st/ops/gpu/test_acosh_grad_op.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+import mindspore.ops.operations._grad_ops as P
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+np.random.seed(1)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_acoshgrad_fp32():
+    y_np = np.random.rand(4, 2).astype(np.float32) * 10
+    dout_np = np.random.rand(4, 2).astype(np.float32) * 10
+    output_ms = P.AcoshGrad()(Tensor(y_np), Tensor(dout_np))
+    output_np = dout_np / np.sinh(y_np)
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_acoshgrad_fp16():
+    y_np = np.random.rand(4, 2).astype(np.float16) * 10
+    dout_np = np.random.rand(4, 2).astype(np.float16) * 10
+    output_ms = P.AcoshGrad()(Tensor(y_np), Tensor(dout_np))
+    output_np = dout_np.astype(np.float32) / np.sinh(y_np).astype(np.float32)
+    assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)
diff --git a/tests/st/ops/gpu/test_acosh_op.py b/tests/st/ops/gpu/test_acosh_op.py
new file mode 100644
index 0000000000..69f7b72707
--- /dev/null
+++ b/tests/st/ops/gpu/test_acosh_op.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+np.random.seed(1)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_acosh_fp32():
+    x_np = np.random.rand(4, 2).astype(np.float32) * 10 + 1
+    output_ms = P.Acosh()(Tensor(x_np))
+    output_np = np.arccosh(x_np)
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_acosh_fp16():
+    x_np = np.random.rand(4, 2).astype(np.float16) * 10 + 1
+    output_ms = P.Acosh()(Tensor(x_np))
+    output_np = np.arccosh(x_np.astype(np.float32)).astype(np.float16)
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
diff --git a/tests/st/ops/gpu/test_asinh_grad_op.py b/tests/st/ops/gpu/test_asinh_grad_op.py
new file mode 100644
index 0000000000..a42d820386
--- /dev/null
+++ b/tests/st/ops/gpu/test_asinh_grad_op.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+import mindspore.ops.operations._grad_ops as P
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+np.random.seed(1)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_asinhgrad_fp32():
+    y_np = np.random.rand(4, 2).astype(np.float32) * 10
+    dout_np = np.random.rand(4, 2).astype(np.float32) * 10
+    output_ms = P.AsinhGrad()(Tensor(y_np), Tensor(dout_np))
+    output_np = dout_np / np.cosh(y_np)
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_asinhgrad_fp16():
+    y_np = np.random.rand(4, 2).astype(np.float16) * 10
+    dout_np = np.random.rand(4, 2).astype(np.float16) * 10
+    output_ms = P.AsinhGrad()(Tensor(y_np), Tensor(dout_np))
+    output_np = dout_np.astype(np.float32) / np.cosh(y_np).astype(np.float32)
+    assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)
diff --git a/tests/st/ops/gpu/test_asinh_op.py b/tests/st/ops/gpu/test_asinh_op.py
new file mode 100644
index 0000000000..21b1fceece
--- /dev/null
+++ b/tests/st/ops/gpu/test_asinh_op.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+np.random.seed(1)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_asinh_fp32():
+    x_np = np.random.rand(4, 2).astype(np.float32) * 10
+    output_ms = P.Asinh()(Tensor(x_np))
+    output_np = np.arcsinh(x_np)
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_asinh_fp16():
+    x_np = np.random.rand(4, 2).astype(np.float16) * 10
+    output_ms = P.Asinh()(Tensor(x_np))
+    output_np = np.arcsinh(x_np.astype(np.float32)).astype(np.float16)
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
diff --git a/tests/st/ops/gpu/test_atan_grad_op.py b/tests/st/ops/gpu/test_atan_grad_op.py
new file mode 100644
index 0000000000..986709734b
--- /dev/null
+++ b/tests/st/ops/gpu/test_atan_grad_op.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+import mindspore.ops.operations._grad_ops as P
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+np.random.seed(1)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_atangrad_fp32():
+    x_np = np.random.rand(4, 2).astype(np.float32) * 10
+    dout_np = np.random.rand(4, 2).astype(np.float32) * 10
+    output_ms = P.AtanGrad()(Tensor(x_np), Tensor(dout_np))
+    output_np = dout_np / (1 + np.square(x_np))
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_atangrad_fp16():
+    x_np = np.random.rand(4, 2).astype(np.float16) * 10
+    dout_np = np.random.rand(4, 2).astype(np.float16) * 10
+    output_ms = P.AtanGrad()(Tensor(x_np), Tensor(dout_np))
+    output_np = dout_np.astype(np.float32) / (1 + np.square(x_np.astype(np.float32)))
+    assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)
diff --git a/tests/st/ops/gpu/test_atan_op.py b/tests/st/ops/gpu/test_atan_op.py
new file mode 100644
index 0000000000..ea1ca25a89
--- /dev/null
+++ b/tests/st/ops/gpu/test_atan_op.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.ops import operations as P
+context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+np.random.seed(1)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_atan_fp32():
+    x_np = np.random.rand(4, 2).astype(np.float32) * 10
+    output_ms = P.Atan()(Tensor(x_np))
+    output_np = np.arctan(x_np)
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_atan_fp16():
+    x_np = np.random.rand(4, 2).astype(np.float16) * 10
+    output_ms = P.Atan()(Tensor(x_np))
+    output_np = np.arctan(x_np.astype(np.float32)).astype(np.float16)
+    assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)
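For local verification, each of these system tests can be run directly with pytest on a GPU-equipped machine, for example `pytest -sv tests/st/ops/gpu/test_atan_op.py`.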