Add new GPU operators: Asinh, AsinhGrad, Acosh, AcoshGrad, Atan and AtanGrad

pull/10580/head
hedongdong 4 years ago
parent 3159fb462c
commit 352a362878
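The new operators plug into the existing UnaryOpGpuKernel / UnaryGradOpGpuKernel machinery and are exercised from Python exactly as in the test cases added at the end of this commit. A minimal usage sketch (assuming a GPU build of MindSpore with these kernels compiled in; the shapes, values, and the alias G for the grad-op module are illustrative):

import numpy as np
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.ops.operations._grad_ops as G

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

x = Tensor(np.random.rand(4, 2).astype(np.float32) * 10)     # forward input
dout = Tensor(np.random.rand(4, 2).astype(np.float32) * 10)  # incoming gradient

atan_y = P.Atan()(x)    # elementwise arctan(x)
asinh_y = P.Asinh()(x)  # elementwise arcsinh(x)
acosh_y = P.Acosh()(Tensor(np.random.rand(4, 2).astype(np.float32) * 10 + 1))  # needs x >= 1

dx_atan = G.AtanGrad()(x, dout)          # dout / (1 + x * x), takes the forward input
dx_asinh = G.AsinhGrad()(asinh_y, dout)  # dout / cosh(y), takes the forward output
dx_acosh = G.AcoshGrad()(acosh_y, dout)  # dout / sinh(y), takes the forward output
print(dx_atan.asnumpy(), dx_asinh.asnumpy(), dx_acosh.asnumpy())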

@@ -76,6 +76,33 @@ __global__ void ACosGradKernel(const half *input, const half *dout, half *output
return;
}
template <typename T>
__global__ void AtanGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
T one = 1;
T divisor = one + input[i] * input[i];
output[i] = dout[i] / divisor;
}
return;
}
template <typename T>
__global__ void AsinhGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
float inputf = static_cast<float>(input[i]);
T coshy = static_cast<T>(coshf(inputf));
output[i] = dout[i] / coshy;
}
return;
}
template <typename T>
__global__ void AcoshGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
float inputf = static_cast<float>(input[i]);
T sinhy = static_cast<T>(sinhf(inputf));
output[i] = dout[i] / sinhy;
}
return;
}
template <typename T>
void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
@@ -98,6 +125,24 @@ void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cuda
return;
}
template <typename T>
void AtanGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
AtanGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
}
template <typename T>
void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
AsinhGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
}
template <typename T>
void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
AcoshGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
}
template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
@@ -106,6 +151,12 @@ template void AsinGrad<float>(const float *input, const float *dout, float *outp
cudaStream_t cuda_stream);
template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void AtanGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void AsinhGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void AcoshGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
@@ -114,3 +165,9 @@ template void AsinGrad<half>(const half *input, const half *dout, half *output,
cudaStream_t cuda_stream);
template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void AtanGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void AsinhGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void AcoshGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
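For reference, the three grad kernels above implement the standard derivative identities, written here with x for the forward input, y for the forward output, and dout for the incoming gradient:

\frac{d}{dx}\operatorname{asinh}(x) = \frac{1}{\sqrt{1 + x^{2}}} = \frac{1}{\cosh(y)} \quad\Rightarrow\quad \text{AsinhGrad}(y, dout) = \frac{dout}{\cosh(y)}
\frac{d}{dx}\operatorname{acosh}(x) = \frac{1}{\sqrt{x^{2} - 1}} = \frac{1}{\sinh(y)} \quad\Rightarrow\quad \text{AcoshGrad}(y, dout) = \frac{dout}{\sinh(y)}
\frac{d}{dx}\arctan(x) = \frac{1}{1 + x^{2}} \quad\Rightarrow\quad \text{AtanGrad}(x, dout) = \frac{dout}{1 + x^{2}}

Hence AsinhGrad and AcoshGrad divide by the hyperbolic functions of the forward output (the first kernel argument), while AtanGrad divides by 1 + x^2 of the forward input, matching the NumPy references in the tests added below.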

@@ -26,5 +26,12 @@ template <typename T>
void AsinGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void AtanGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void AsinhGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void AcoshGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_

@@ -146,6 +146,15 @@ __global__ void AsinKernel(const T *input, T *output, const size_t count) {
return;
}
template <typename T>
__global__ void AsinhKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
float inputf = static_cast<float>(input[i]);
T res = static_cast<T>(asinhf(inputf));
output[i] = res;
}
return;
}
template <typename T>
__global__ void CosKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = cos(input[i]);
@@ -169,6 +178,24 @@ __global__ void ACosKernel(const T *input, T *output, const size_t count) {
return;
}
template <typename T>
__global__ void AcoshKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
float inputf = static_cast<float>(input[i]);
T res = static_cast<T>(acoshf(inputf));
output[i] = res;
}
return;
}
template <typename T>
__global__ void AtanKernel(const T *input, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
float inputf = static_cast<float>(input[i]);
T res = static_cast<T>(atanf(inputf));
output[i] = res;
}
return;
}
template <typename T>
__global__ void ZeroslikeKernel(T *output, const size_t count) {
T zero = 0.0;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -281,6 +308,21 @@ void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
return;
}
template <typename T>
void Atan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
AtanKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
AsinhKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
AcoshKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
}
template <typename T>
void Rsqrt(const T *input, T *output, const size_t count, cudaStream_t cuda_stream) {
RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
return;
@@ -315,6 +357,9 @@ template void Sin<float>(const float *input, float *output, const size_t count,
template void Cos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Asin<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void ACos<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Atan<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Asinh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Acosh<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Rsqrt<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
template void Zeroslike<float>(float *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<float>(const float *input, float *output, const size_t count, cudaStream_t cuda_stream);
@@ -333,6 +378,9 @@ template void Sin<half>(const half *input, half *output, const size_t count, cud
template void Cos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Asin<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void ACos<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Atan<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Asinh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Acosh<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Rsqrt<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
template void Zeroslike<half>(half *output, const size_t count, cudaStream_t cuda_stream);
template void Abs<half>(const half *input, half *output, const size_t count, cudaStream_t cuda_stream);
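Note that the Asinh, Acosh, and Atan kernels above go through a float intermediate (asinhf, acoshf, atanf) for every instantiated type, which is how the half specializations are realized. The fp16 tests below validate against the same pattern; a minimal NumPy sketch of that reference computation (array contents are illustrative):

import numpy as np
x16 = np.random.rand(4, 2).astype(np.float16) * 10
ref16 = np.arcsinh(x16.astype(np.float32)).astype(np.float16)  # compute in fp32, cast back to fp16
# a GPU Asinh(half) result would then be compared via np.allclose(result, ref16, 1e-3, 1e-3)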

@@ -49,6 +49,12 @@ void Asin(const T *input, T *output, const size_t count, cudaStream_t cuda_strea
template <typename T>
void ACos(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Atan(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Asinh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Acosh(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Zeroslike(T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void Abs(const T *input, T *output, const size_t count, cudaStream_t cuda_stream);

@@ -74,6 +74,10 @@ MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
@@ -82,6 +86,14 @@ MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Abs, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),

@@ -44,6 +44,9 @@ enum UnaryOptype {
UNARY_OP_COS,
UNARY_OP_ASIN,
UNARY_OP_ACOS,
UNARY_OP_ATAN,
UNARY_OP_ASINH,
UNARY_OP_ACOSH,
UNARY_OP_ABS,
UNARY_OP_FLOOR,
UNARY_OP_INVALID_TYPE = 255
@@ -64,6 +67,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
{"Cos", UNARY_OP_COS},
{"Asin", UNARY_OP_ASIN},
{"ACos", UNARY_OP_ACOS},
{"Atan", UNARY_OP_ATAN},
{"Asinh", UNARY_OP_ASINH},
{"Acosh", UNARY_OP_ACOSH},
{"Abs", UNARY_OP_ABS},
{"Floor", UNARY_OP_FLOOR}};
template <typename T>
@@ -142,6 +148,18 @@ class UnaryOpGpuKernel : public GpuKernel {
ACos(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ATAN: {
Atan(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ASINH: {
Asinh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ACOSH: {
Acosh(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ZEROSLIKE: {
Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
return true;

@@ -50,5 +50,29 @@ MS_REG_GPU_KERNEL_ONE(
ACosGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
AtanGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryGradOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
AtanGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
AsinhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryGradOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
AsinhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
AcoshGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryGradOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
AcoshGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
} // namespace kernel
} // namespace mindspore

@@ -32,12 +32,16 @@ enum UnaryGradOptype {
UNARY_OP_RSQRT_GRAD = 1,
UNARY_OP_ASIN_GRAD = 2,
UNARY_OP_ACOS_GRAD = 3,
UNARY_OP_ATAN_GRAD = 4,
UNARY_OP_ASINH_GRAD = 5,
UNARY_OP_ACOSH_GRAD = 6,
UNARY_OP_GRAD_INVALID_TYPE = 255
};
static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {{"SqrtGrad", UNARY_OP_SQRT_GRAD},
{"RsqrtGrad", UNARY_OP_RSQRT_GRAD},
{"AsinGrad", UNARY_OP_ASIN_GRAD},
{"ACosGrad", UNARY_OP_ACOS_GRAD}};
static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {
{"SqrtGrad", UNARY_OP_SQRT_GRAD}, {"RsqrtGrad", UNARY_OP_RSQRT_GRAD}, {"AsinGrad", UNARY_OP_ASIN_GRAD},
{"ACosGrad", UNARY_OP_ACOS_GRAD}, {"AtanGrad", UNARY_OP_ATAN_GRAD}, {"AsinhGrad", UNARY_OP_ASINH_GRAD},
{"AcoshGrad", UNARY_OP_ACOSH_GRAD}};
template <typename T>
class UnaryGradOpGpuKernel : public GpuKernel {
public:
@@ -77,6 +81,21 @@ class UnaryGradOpGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ATAN_GRAD: {
AtanGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ASINH_GRAD: {
AsinhGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ACOSH_GRAD: {
AcoshGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_RSQRT_GRAD: {
RsqrtGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));

@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations._grad_ops as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
np.random.seed(1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_acoshgrad_fp32():
    y_np = np.random.rand(4, 2).astype(np.float32) * 10
    dout_np = np.random.rand(4, 2).astype(np.float32) * 10
    output_ms = P.AcoshGrad()(Tensor(y_np), Tensor(dout_np))
    output_np = dout_np / np.sinh(y_np)
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_acoshgrad_fp16():
    y_np = np.random.rand(4, 2).astype(np.float16) * 10
    dout_np = np.random.rand(4, 2).astype(np.float16) * 10
    output_ms = P.AcoshGrad()(Tensor(y_np), Tensor(dout_np))
    output_np = dout_np.astype(np.float32) / np.sinh(y_np).astype(np.float32)
    assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
np.random.seed(1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_acosh_fp32():
    x_np = np.random.rand(4, 2).astype(np.float32) * 10 + 1
    output_ms = P.Acosh()(Tensor(x_np))
    output_np = np.arccosh(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_acosh_fp16():
    x_np = np.random.rand(4, 2).astype(np.float16) * 10 + 1
    output_ms = P.Acosh()(Tensor(x_np))
    output_np = np.arccosh(x_np.astype(np.float32)).astype(np.float16)
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)

@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations._grad_ops as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
np.random.seed(1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asinhgrad_fp32():
    y_np = np.random.rand(4, 2).astype(np.float32) * 10
    dout_np = np.random.rand(4, 2).astype(np.float32) * 10
    output_ms = P.AsinhGrad()(Tensor(y_np), Tensor(dout_np))
    output_np = dout_np / np.cosh(y_np)
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asinhgrad_fp16():
    y_np = np.random.rand(4, 2).astype(np.float16) * 10
    dout_np = np.random.rand(4, 2).astype(np.float16) * 10
    output_ms = P.AsinhGrad()(Tensor(y_np), Tensor(dout_np))
    output_np = dout_np.astype(np.float32) / np.cosh(y_np).astype(np.float32)
    assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
np.random.seed(1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asinh_fp32():
    x_np = np.random.rand(4, 2).astype(np.float32) * 10
    output_ms = P.Asinh()(Tensor(x_np))
    output_np = np.arcsinh(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asinh_fp16():
    x_np = np.random.rand(4, 2).astype(np.float16) * 10
    output_ms = P.Asinh()(Tensor(x_np))
    output_np = np.arcsinh(x_np.astype(np.float32)).astype(np.float16)
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)

@@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations._grad_ops as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
np.random.seed(1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atangrad_fp32():
    x_np = np.random.rand(4, 2).astype(np.float32) * 10
    dout_np = np.random.rand(4, 2).astype(np.float32) * 10
    output_ms = P.AtanGrad()(Tensor(x_np), Tensor(dout_np))
    output_np = dout_np / (1 + np.square(x_np))
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atangrad_fp16():
    x_np = np.random.rand(4, 2).astype(np.float16) * 10
    dout_np = np.random.rand(4, 2).astype(np.float16) * 10
    output_ms = P.AtanGrad()(Tensor(x_np), Tensor(dout_np))
    output_np = dout_np.astype(np.float32) / (1 + np.square(x_np.astype(np.float32)))
    assert np.allclose(output_ms.asnumpy(), output_np.astype(np.float16), 1e-3, 1e-3)

@@ -0,0 +1,41 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
from mindspore.ops import operations as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
np.random.seed(1)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atan_fp32():
    x_np = np.random.rand(4, 2).astype(np.float32) * 10
    output_ms = P.Atan()(Tensor(x_np))
    output_np = np.arctan(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-4, 1e-4)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_atan_fp16():
    x_np = np.random.rand(4, 2).astype(np.float16) * 10
    output_ms = P.Atan()(Tensor(x_np))
    output_np = np.arctan(x_np.astype(np.float32)).astype(np.float16)
    assert np.allclose(output_ms.asnumpy(), output_np, 1e-3, 1e-3)