!12362 Add float64 support to Abs, Neg, Sqrt, Sin, Cos

From: @peilin-wang
Reviewed-by: @robingrosman, @tom__chen
Signed-off-by: @robingrosman
pull/12362/MERGE
mindspore-ci-bot committed 4 years ago via Gitee
commit 4b27c3206d
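
Note: with this change, the five ops in the title accept float64 tensors on the GPU backend. A minimal usage sketch (assumes a GPU build of MindSpore containing this commit; the op and context APIs are the standard ones already used by the tests below):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.random.uniform(0.1, 2.0, (2, 3)).astype(np.float64))
    for op, ref in [(P.Abs(), np.abs), (P.Neg(), np.negative),
                    (P.Sqrt(), np.sqrt), (P.Sin(), np.sin), (P.Cos(), np.cos)]:
        out = op(x)
        # The registered float64 kernels keep the dtype end to end.
        assert out.asnumpy().dtype == np.float64
        assert np.allclose(out.asnumpy(), ref(x.asnumpy()))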

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -343,6 +343,31 @@ void Floor(const T *input, T *output, const size_t count, cudaStream_t cuda_stre
   return;
 }
 
+// double
+template void Exponential<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Expm1<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Logarithm<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Log1p<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Erf<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Erfc<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Negative<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Reciprocal<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Square<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Sqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Sin<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Cos<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Asin<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void ACos<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Atan<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Asinh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Acosh<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Rsqrt<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Zeroslike<double>(double *output, const size_t count, cudaStream_t cuda_stream);
+template void Abs<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+template void Floor<double>(const double *input, double *output, const size_t count, cudaStream_t cuda_stream);
+
+// float
 template void Exponential<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Expm1<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Logarithm<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
@@ -364,6 +389,8 @@ template void Rsqrt<float>(const float *input, float *output, const size_t count
 template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream);
 template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
 template void Floor<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
+
+// half
 template void Exponential<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Expm1<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
 template void Logarithm<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
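
Note on the hunk above: the CUDA kernel definitions live in this .cu translation unit, so every element type used from other translation units needs an explicit instantiation here; without the `template void Foo<double>(...)` lines, the new double registrations would link against undefined symbols. The PR instantiates the full unary set for double even though only five ops gain registrations in this commit, leaving the remaining float64 registrations available for follow-up. A hedged way to confirm a true float64 path from Python (the 1e-12 tolerance is an assumption, chosen to sit far below float32 accuracy so a silent float32 fallback would fail it):

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    x_np = np.random.rand(1024).astype(np.float64)
    out = P.Sin()(Tensor(x_np)).asnumpy()
    # float32 sin is only good to ~1e-7 relative error; passing at 1e-12
    # implies the double kernel actually ran.
    assert np.allclose(out, np.sin(x_np), rtol=1e-12, atol=1e-12)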

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -42,6 +42,8 @@ MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Erfc, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      UnaryOpGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Neg, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -58,6 +60,8 @@ MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      UnaryOpGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -66,6 +70,8 @@ MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      UnaryOpGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -78,6 +84,8 @@ MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOu
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      UnaryOpGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -94,6 +102,8 @@ MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+                      UnaryOpGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
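
Note on the registrations: each MS_REG_GPU_KERNEL_ONE line binds one op name and one (input dtype, output dtype) signature to UnaryOpGpuKernel instantiated with the matching C++ type; at runtime the framework selects the registration whose KernelAttr matches the tensor dtypes, so the new kNumberTypeFloat64 lines are what expose the double instantiations to graphs. A small sketch of the dtype-driven dispatch as seen from Python:

    import numpy as np
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    sqrt = P.Sqrt()
    for nptype in (np.float16, np.float32, np.float64):
        y = sqrt(Tensor(np.array([1.0, 4.0, 9.0]).astype(nptype)))
        # Each dtype hits its own registration; output dtype follows input.
        assert y.asnumpy().dtype == nptype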

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,14 +20,29 @@ import mindspore.context as context
 from mindspore import Tensor
 from mindspore.ops import operations as P
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_cos():
-    x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
+def cos(nptype):
+    np.random.seed(0)
+    x_np = np.random.rand(2, 3, 4, 4).astype(nptype)
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     output_ms = P.Cos()(Tensor(x_np))
     output_np = np.cos(x_np)
     assert np.allclose(output_ms.asnumpy(), output_np)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cos_float16():
+    cos(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cos_float32():
+    cos(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_cos_float64():
+    cos(np.float64)
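
The test refactor above is the pattern used across this PR: the body moves into an undecorated helper parameterized by NumPy dtype, and one decorated test per dtype calls it, so each precision reports as its own CI case. The same matrix could be written with pytest's own parameterization; this is only an illustrative sketch of the same idea, not the repository's convention:

    import numpy as np
    import pytest
    import mindspore.context as context
    from mindspore import Tensor
    from mindspore.ops import operations as P

    @pytest.mark.level0
    @pytest.mark.platform_x86_gpu_training
    @pytest.mark.env_onecard
    @pytest.mark.parametrize("nptype", [np.float16, np.float32, np.float64])
    def test_cos(nptype):
        np.random.seed(0)
        x_np = np.random.rand(2, 3, 4, 4).astype(nptype)
        context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
        assert np.allclose(P.Cos()(Tensor(x_np)).asnumpy(), np.cos(x_np))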

@@ -14,15 +14,14 @@
 # ============================================================================
 """ test loss """
 import numpy as np
-import mindspore
+import pytest
 from mindspore import Tensor
 from mindspore.ops import operations as P
 from mindspore.nn.loss.loss import _Loss
 from mindspore.nn.loss.loss import L1Loss
 import mindspore.context as context
 
-context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
 
 class WeightedLoss(_Loss):
     def __init__(self, reduction='mean', weights=1.0):
         super(WeightedLoss, self).__init__(reduction)
@@ -33,10 +32,13 @@ class WeightedLoss(_Loss):
         x = self.abs(base - target)
         return self.get_loss(x, self.weights)
 
-def test_WeightedLoss():
+def weighted_loss(nptype):
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
     loss = WeightedLoss()
-    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
-    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
+    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype))
+    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype))
     output_data = loss(input_data, target_data)
     error_range = np.ones(shape=output_data.shape) * 10e-6
@@ -50,14 +52,26 @@ def test_WeightedLoss():
     diff = test_output - output_data * 3
     assert np.all(abs(diff.asnumpy()) < error_range)
 
-    loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]])))
-    y_true = Tensor(np.array([[0., 1.], [0., 0.]]), mindspore.float32)
-    y_pred = Tensor(np.array([[1., 1.], [1., 0.]]), mindspore.float32)
+    loss = WeightedLoss(weights=Tensor(np.array([[0.7, 0.3], [0.7, 0.3]]).astype(nptype)))
+    y_true = Tensor(np.array([[0., 1.], [0., 0.]]).astype(nptype))
+    y_pred = Tensor(np.array([[1., 1.], [1., 0.]]).astype(nptype))
     test_data = 0.35
     output = loss(y_true, y_pred)
     diff = test_data - output.asnumpy()
     assert np.all(abs(diff) < error_range)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_weighted_loss_float32():
+    weighted_loss(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_weighted_loss_float64():
+    weighted_loss(np.float64)
 
 class CustomLoss(_Loss):
     def __init__(self, reduction='mean'):
         super(CustomLoss, self).__init__(reduction)
@@ -67,10 +81,10 @@ class CustomLoss(_Loss):
         x = self.abs(base - target)
         return self.get_loss(x, weights=2.0)
 
-def test_CustomLoss():
+def custom_loss(nptype):
     loss = L1Loss()
-    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32))
-    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32))
+    input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(nptype))
+    target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(nptype))
     output_data = loss(input_data, target_data)
     error_range = np.ones(shape=output_data.shape) * 10e-6
@@ -78,3 +92,21 @@ def test_CustomLoss():
     test_output = customloss(input_data, target_data)
     diff = test_output - output_data * 2.0
     assert np.all(abs(diff.asnumpy()) < error_range)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_custom_loss_float16():
+    custom_loss(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_custom_loss_float32():
+    custom_loss(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_custom_loss_float64():
+    custom_loss(np.float64)
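
The hard-coded test_data = 0.35 in weighted_loss follows from the weighted L1 loss with reduction='mean': elementwise |y_true - y_pred| is scaled by the weights and then averaged. The arithmetic, spelled out in NumPy:

    import numpy as np

    y_true = np.array([[0., 1.], [0., 0.]])
    y_pred = np.array([[1., 1.], [1., 0.]])
    weights = np.array([[0.7, 0.3], [0.7, 0.3]])

    l1 = np.abs(y_true - y_pred)        # [[1, 0], [1, 0]]
    weighted = weights * l1             # [[0.7, 0], [0.7, 0]]
    assert np.isclose(weighted.mean(), 0.35)   # 1.4 / 4 elements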

@@ -1,4 +1,4 @@
-# Copyright 2019 Huawei Technologies Co., Ltd
+# Copyright 2019-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,12 +31,9 @@ class NetNeg(nn.Cell):
         return self.neg(x)
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_neg():
-    x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32)
-    x1_np = np.random.uniform(-2, 2, 1).astype(np.float32)
+def neg(nptype):
+    x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(nptype)
+    x1_np = np.random.uniform(-2, 2, 1).astype(nptype)
     x0 = Tensor(x0_np)
     x1 = Tensor(x1_np)
     expect0 = np.negative(x0_np)
@@ -45,23 +42,41 @@ def test_neg():
     error1 = np.ones(shape=expect1.shape) * 1.0e-5
 
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
-    neg = NetNeg()
-    output0 = neg(x0)
+    neg_net = NetNeg()
+    output0 = neg_net(x0)
     diff0 = output0.asnumpy() - expect0
     assert np.all(diff0 < error0)
     assert output0.shape == expect0.shape
-    output1 = neg(x1)
+    output1 = neg_net(x1)
     diff1 = output1.asnumpy() - expect1
     assert np.all(diff1 < error1)
     assert output1.shape == expect1.shape
 
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    neg = NetNeg()
-    output0 = neg(x0)
+    neg_net = NetNeg()
+    output0 = neg_net(x0)
     diff0 = output0.asnumpy() - expect0
     assert np.all(diff0 < error0)
     assert output0.shape == expect0.shape
-    output1 = neg(x1)
+    output1 = neg_net(x1)
     diff1 = output1.asnumpy() - expect1
     assert np.all(diff1 < error1)
     assert output1.shape == expect1.shape
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_neg_float16():
+    neg(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_neg_float32():
+    neg(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_neg_float64():
+    neg(np.float64)
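
The rename from neg to neg_net keeps the local NetNeg instance from shadowing the new helper, which is itself named neg. Reusing the fixed 1.0e-5 bound across float16/32/64 is also safe here because negation is exact at every IEEE width: it only flips the sign bit and introduces no rounding. A quick NumPy demonstration of that exactness:

    import numpy as np

    x = np.random.uniform(-2, 2, 16).astype(np.float64)
    # Negation == flipping the sign bit; the bit patterns match exactly.
    neg_bits = np.negative(x).view(np.uint64)
    flip_bits = x.view(np.uint64) ^ np.uint64(1 << 63)
    assert np.array_equal(neg_bits, flip_bits)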

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,14 +20,29 @@ import mindspore.context as context
 from mindspore import Tensor
 from mindspore.ops import operations as P
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_sin():
-    x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
+def sin(nptype):
+    np.random.seed(0)
+    x_np = np.random.rand(2, 3, 4, 4).astype(nptype)
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     output_ms = P.Sin()(Tensor(x_np))
     output_np = np.sin(x_np)
     assert np.allclose(output_ms.asnumpy(), output_np)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sin_float16():
+    sin(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sin_float32():
+    sin(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sin_float64():
+    sin(np.float64)
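
One subtlety in these checks: np.sin on a float16 array returns float16, so the reference matches the kernel's output dtype, but np.allclose's default rtol=1e-05 is far tighter than float16's machine epsilon (~9.8e-04). The float16 tests therefore effectively demand agreement to the last ulp after rounding; a one-ulp discrepancy already fails:

    import numpy as np

    x = np.float16(0.5)
    y = np.float16(0.5005)      # the next representable float16 above 0.5
    print(float(y - x))         # 0.00048828125, one float16 ulp
    print(np.allclose(x, y))    # False: one ulp exceeds rtol * |x|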

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,18 +20,40 @@ import mindspore.context as context
 from mindspore import Tensor
 from mindspore.ops import operations as P
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_sqrt():
-    x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
+def sqrt(nptype):
+    np.random.seed(0)
+    x_np = np.random.rand(2, 3, 4, 4).astype(nptype)
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     output_ms = P.Sqrt()(Tensor(x_np))
     output_np = np.sqrt(x_np)
     assert np.allclose(output_ms.asnumpy(), output_np)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sqrt_float16():
+    sqrt(np.float16)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sqrt_float32():
+    sqrt(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sqrt_float64():
+    sqrt(np.float64)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_rsqrt():
+    np.random.seed(0)
+    x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
     output_ms = P.Rsqrt()(Tensor(x_np))
     output_np = 1 / np.sqrt(x_np)
     assert np.allclose(output_ms.asnumpy(), output_np)
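
test_rsqrt stays float32-only even though Rsqrt<double> is instantiated in the CUDA file above; no float64 registration for Rsqrt appears in this diff, so the double kernel is not yet reachable and presumably waits on a follow-up. The identity the test encodes, restated in NumPy with the same seeded input:

    import numpy as np

    np.random.seed(0)
    x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
    # Rsqrt(x) is defined as 1 / sqrt(x); this is the reference the GPU
    # output is compared against.
    reference = 1 / np.sqrt(x_np)
    assert np.allclose(reference, x_np ** -0.5)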
