diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu index d7286e08d5..64a4008ca9 100755 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -343,6 +343,31 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre return; } +// double +template void Exponential(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Expm1(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Logarithm(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Log1p(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Erf(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Erfc(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Negative(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Reciprocal(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Square(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Sqrt(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Sin(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Cos(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Asin(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void ACos(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Atan(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Asinh(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Acosh(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Rsqrt(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Zeroslike(double *output, const size_t count, cudaStream_t cuda_stream); +template void Abs(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); +template void Floor(const double *input, double *output, const size_t count, cudaStream_t cuda_stream); + + +// float template void Exponential(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); template void Expm1(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); template void Logarithm(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); @@ -364,6 +389,8 @@ template void Rsqrt(const float *input, float *output, const size_t count template void Zeroslike(float *output, const size_t count, cudaStream_t cuda_stream); template void Abs(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); template void Floor(const float *input, float *output, const size_t count, cudaStream_t cuda_stream); + +// half template void Exponential(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); template void Expm1(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); template void Logarithm(const half *input, half *output, const size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc index 91571c818b..89599fd637 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2020 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,6 +42,8 @@ MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + UnaryOpGpuKernel, double) MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), @@ -58,6 +60,8 @@ MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + UnaryOpGpuKernel, double) MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), @@ -66,6 +70,8 @@ MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + UnaryOpGpuKernel, double) MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), @@ -78,6 +84,8 @@ MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + UnaryOpGpuKernel, double) MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), @@ -94,6 +102,8 @@ MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), + UnaryOpGpuKernel, double) MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), diff --git a/tests/st/ops/gpu/test_cos_op.py b/tests/st/ops/gpu/test_cos_op.py index 4feb1aec09..bc1cf9029e 100644 --- a/tests/st/ops/gpu/test_cos_op.py +++ b/tests/st/ops/gpu/test_cos_op.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,14 +20,29 @@ import mindspore.context as context from mindspore import Tensor from mindspore.ops import operations as P - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_cos(): - x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) +def cos(nptype): + np.random.seed(0) + x_np = np.random.rand(2, 3, 4, 4).astype(nptype) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") output_ms = P.Cos()(Tensor(x_np)) output_np = np.cos(x_np) assert np.allclose(output_ms.asnumpy(), output_np) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cos_float16(): + cos(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cos_float32(): + cos(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_cos_float64(): + cos(np.float64) diff --git a/tests/st/ops/gpu/test_loss.py b/tests/st/ops/gpu/test_loss.py index a5d95a9859..8f43bc1681 100644 --- a/tests/st/ops/gpu/test_loss.py +++ b/tests/st/ops/gpu/test_loss.py @@ -14,15 +14,14 @@ # ============================================================================ """ test loss """ import numpy as np -import mindspore +import pytest + from mindspore import Tensor from mindspore.ops import operations as P from mindspore.nn.loss.loss import _Loss from mindspore.nn.loss.loss import L1Loss import mindspore.context as context -context.set_context(mode=context.GRAPH_MODE, device_target='GPU') - class WeightedLoss(_Loss): def __init__(self, reduction='mean', weights=1.0): super(WeightedLoss, self).__init__(reduction) @@ -33,10 +32,13 @@ class WeightedLoss(_Loss): x = self.abs(base - target) return self.get_loss(x, self.weights) -def test_WeightedLoss(): + +def weighted_loss(nptype): + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + loss = WeightedLoss() - input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32)) - target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32)) + input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype)) + target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype)) output_data = loss(input_data, target_data) error_range = np.ones(shape=output_data.shape) * 10e-6 @@ -50,14 +52,26 @@ def test_WeightedLoss(): diff = test_output - output_data * 3 assert np.all(abs(diff.asnumpy()) < error_range) - loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]]))) - y_true = Tensor(np.array([[0., 1.], [0., 0.]]), mindspore.float32) - y_pred = Tensor(np.array([[1., 1.], [1., 0.]]), mindspore.float32) + loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]]).astype(nptype))) + y_true = Tensor(np.array([[0., 1.], [0., 0.]]).astype(nptype)) + y_pred = Tensor(np.array([[1., 1.], [1., 0.]]).astype(nptype)) test_data = 0.35 output = loss(y_true, y_pred) diff = test_data - output.asnumpy() assert np.all(abs(diff) < error_range) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_weighted_loss_float32(): + weighted_loss(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_weighted_loss_float64(): + weighted_loss(np.float64) + class CustomLoss(_Loss): def __init__(self, reduction='mean'): super(CustomLoss, self).__init__(reduction) @@ -67,10 +81,10 @@ class CustomLoss(_Loss): x = self.abs(base - target) return self.get_loss(x, weights=2.0) -def test_CustomLoss(): +def custom_loss(nptype): loss = L1Loss() - input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32)) - target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32)) + input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype)) + target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype)) output_data = loss(input_data, target_data) error_range = np.ones(shape=output_data.shape) * 10e-6 @@ -78,3 +92,21 @@ def test_CustomLoss(): test_output = customloss(input_data, target_data) diff = test_output - output_data * 2.0 assert np.all(abs(diff.asnumpy()) < error_range) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_custom_loss_float16(): + custom_loss(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_custom_loss_float32(): + custom_loss(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_custom_loss_float64(): + custom_loss(np.float64) diff --git a/tests/st/ops/gpu/test_neg_op.py b/tests/st/ops/gpu/test_neg_op.py index 2f1d662bfc..baa5f7cb0f 100644 --- a/tests/st/ops/gpu/test_neg_op.py +++ b/tests/st/ops/gpu/test_neg_op.py @@ -1,4 +1,4 @@ -# Copyright 2019 Huawei Technologies Co., Ltd +# Copyright 2019-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,12 +31,9 @@ class NetNeg(nn.Cell): return self.neg(x) -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_neg(): - x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32) - x1_np = np.random.uniform(-2, 2, 1).astype(np.float32) +def neg(nptype): + x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype) + x1_np = np.random.uniform(-2, 2, 1).astype(nptype) x0 = Tensor(x0_np) x1 = Tensor(x1_np) expect0 = np.negative(x0_np) @@ -45,23 +42,41 @@ def test_neg(): error1 = np.ones(shape=expect1.shape) * 1.0e-5 context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") - neg = NetNeg() - output0 = neg(x0) + neg_net = NetNeg() + output0 = neg_net(x0) diff0 = output0.asnumpy() - expect0 assert np.all(diff0 < error0) assert output0.shape == expect0.shape - output1 = neg(x1) + output1 = neg_net(x1) diff1 = output1.asnumpy() - expect1 assert np.all(diff1 < error1) assert output1.shape == expect1.shape context.set_context(mode=context.GRAPH_MODE, device_target="GPU") - neg = NetNeg() - output0 = neg(x0) + neg_net = NetNeg() + output0 = neg_net(x0) diff0 = output0.asnumpy() - expect0 assert np.all(diff0 < error0) assert output0.shape == expect0.shape - output1 = neg(x1) + output1 = neg_net(x1) diff1 = output1.asnumpy() - expect1 assert np.all(diff1 < error1) assert output1.shape == expect1.shape + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_neg_float16(): + neg(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_neg_float32(): + neg(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_neg_float64(): + neg(np.float64) diff --git a/tests/st/ops/gpu/test_sin_op.py b/tests/st/ops/gpu/test_sin_op.py index 117a7a8811..9f6fcb500c 100644 --- a/tests/st/ops/gpu/test_sin_op.py +++ b/tests/st/ops/gpu/test_sin_op.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,14 +20,29 @@ import mindspore.context as context from mindspore import Tensor from mindspore.ops import operations as P - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_sin(): - x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) +def sin(nptype): + np.random.seed(0) + x_np = np.random.rand(2, 3, 4, 4).astype(nptype) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") output_ms = P.Sin()(Tensor(x_np)) output_np = np.sin(x_np) assert np.allclose(output_ms.asnumpy(), output_np) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sin_float16(): + sin(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sin_float32(): + sin(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sin_float64(): + sin(np.float64) diff --git a/tests/st/ops/gpu/test_sqrt_op.py b/tests/st/ops/gpu/test_sqrt_op.py index 35c06d4c30..c64c2c4293 100644 --- a/tests/st/ops/gpu/test_sqrt_op.py +++ b/tests/st/ops/gpu/test_sqrt_op.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,18 +20,40 @@ import mindspore.context as context from mindspore import Tensor from mindspore.ops import operations as P - -@pytest.mark.level0 -@pytest.mark.platform_x86_gpu_training -@pytest.mark.env_onecard -def test_sqrt(): - x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) +def sqrt(nptype): + np.random.seed(0) + x_np = np.random.rand(2, 3, 4, 4).astype(nptype) context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") output_ms = P.Sqrt()(Tensor(x_np)) output_np = np.sqrt(x_np) assert np.allclose(output_ms.asnumpy(), output_np) +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sqrt_float16(): + sqrt(np.float16) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sqrt_float32(): + sqrt(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sqrt_float64(): + sqrt(np.float64) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_rsqrt(): + np.random.seed(0) + x_np = np.random.rand(2, 3, 4, 4).astype(np.float32) + output_ms = P.Rsqrt()(Tensor(x_np)) output_np = 1 / np.sqrt(x_np) assert np.allclose(output_ms.asnumpy(), output_np)