add sinh, cosh, asinh, acosh, atanh, atan2, asinhgrad, acoshgrad for cpu

pull/12591/head
wangrao 4 years ago
parent 095d7fb877
commit ab17c49eba

@ -16,6 +16,7 @@
#include <cmath>
#include <string>
#include <thread>
#include <map>
#include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
@ -235,45 +236,40 @@ void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out,
}
}
template <typename T>
void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = atan2(input1[idx[0]], input2[idx[1]]);
}
}
static const std::map<std::string, OperateType> kArithmeticBinOpTypeMap = {
{prim::kPrimGreater->name(), GREATER},
{prim::kPrimAdd->name(), ADD},
{prim::kPrimGreaterEqual->name(), GREATEREQUAL},
{prim::kPrimSub->name(), SUB},
{prim::kPrimLogicalAnd->name(), LOGICALAND},
{prim::kPrimMul->name(), MUL},
{prim::kPrimLessEqual->name(), LESSEQUAL},
{prim::kPrimDiv->name(), DIV},
{prim::kPrimLogicalOr->name(), LOGICALOR},
{prim::kPrimMod->name(), MOD},
{prim::kPrimAssignAdd->name(), ASSIGNADD},
{prim::kPrimPow->name(), POW},
{prim::kPrimFloorDiv->name(), FLOORDIV},
{prim::kPrimLess->name(), LESS},
{prim::kPrimNotEqual->name(), NOTEQUAL},
{prim::kPrimAtan2->name(), ATAN2},
{prim::kPrimRealDiv->name(), REALDIV},
{prim::kPrimEqual->name(), EQUAL},
{prim::kPrimSquaredDifference->name(), SQUAREDDIFFERENCE}};
void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == prim::kPrimAdd->name()) {
operate_type_ = ADD;
} else if (kernel_name == prim::kPrimSub->name()) {
operate_type_ = SUB;
} else if (kernel_name == prim::kPrimMul->name()) {
operate_type_ = MUL;
} else if (kernel_name == prim::kPrimRealDiv->name()) {
operate_type_ = REALDIV;
} else if (kernel_name == prim::kPrimDiv->name()) {
operate_type_ = DIV;
} else if (kernel_name == prim::kPrimFloorDiv->name()) {
operate_type_ = FLOORDIV;
} else if (kernel_name == prim::kPrimMod->name()) {
operate_type_ = MOD;
} else if (kernel_name == prim::kPrimPow->name()) {
operate_type_ = POW;
} else if (kernel_name == prim::kPrimLess->name()) {
operate_type_ = LESS;
} else if (kernel_name == prim::kPrimEqual->name()) {
operate_type_ = EQUAL;
} else if (kernel_name == prim::kPrimNotEqual->name()) {
operate_type_ = NOTEQUAL;
} else if (kernel_name == prim::kPrimGreater->name()) {
operate_type_ = GREATER;
} else if (kernel_name == prim::kPrimGreaterEqual->name()) {
operate_type_ = GREATEREQUAL;
} else if (kernel_name == prim::kPrimLessEqual->name()) {
operate_type_ = LESSEQUAL;
} else if (kernel_name == prim::kPrimLogicalAnd->name()) {
operate_type_ = LOGICALAND;
} else if (kernel_name == prim::kPrimLogicalOr->name()) {
operate_type_ = LOGICALOR;
} else if (kernel_name == prim::kPrimAssignAdd->name()) {
operate_type_ = ASSIGNADD;
} else if (kernel_name == prim::kPrimSquaredDifference->name()) {
operate_type_ = SQUAREDDIFFERENCE;
if (kArithmeticBinOpTypeMap.find(kernel_name) != kArithmeticBinOpTypeMap.end()) {
operate_type_ = kArithmeticBinOpTypeMap.at(kernel_name);
} else {
MS_LOG(EXCEPTION) << "Not support " << kernel_name;
}
@ -448,6 +444,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Pow<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ASSIGNADD) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::AssignAdd<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ATAN2) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Atan2<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == SQUAREDDIFFERENCE) {
threads.emplace_back(
std::thread(&ArithmeticCPUKernel::SquaredDifference<T>, this, input1, input2, output, start, end));

@ -58,6 +58,8 @@ class ArithmeticCPUKernel : public CPUKernel {
template <typename T>
void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void Less(const T *input1, const T *input2, bool *out, size_t start, size_t end);
template <typename T>
void Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end);
@ -279,6 +281,10 @@ MS_REG_CPU_KERNEL(
MS_REG_CPU_KERNEL(
LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
ArithmeticCPUKernel);
MS_REG_CPU_KERNEL(
Atan2,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticCPUKernel);
} // namespace kernel
} // namespace mindspore

@ -136,42 +136,68 @@ void Tan(const T *in, T *out, size_t start, size_t end) {
out[i] = tan(in[i]);
}
}
template <typename T>
void Sinh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = sinh(in[i]);
}
}
template <typename T>
void Cosh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = cosh(in[i]);
}
}
template <typename T>
void Asinh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = asinh(in[i]);
}
}
template <typename T>
void Acosh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = acosh(in[i]);
}
}
template <typename T>
void Atanh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = atanh(in[i]);
}
}
} // namespace
static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG},
{prim::kPrimSquare->name(), SQUARE},
{prim::kPrimOnesLike->name(), ONESLIKE},
{prim::kPrimZerosLike->name(), ZEROSLIKE},
{prim::kPrimLogicalNot->name(), LOGICALNOT},
{prim::kPrimSign->name(), SIGN},
{prim::kPrimFloor->name(), FLOOR},
{prim::kPrimReciprocal->name(), RECIPROCAL},
{prim::kPrimGeLU->name(), GELU},
{prim::kPrimAsin->name(), ASIN},
{prim::kPrimACos->name(), ACOS},
{prim::kPrimAtan->name(), ATAN},
{prim::kPrimSin->name(), SIN},
{prim::kPrimCos->name(), COS},
{prim::kPrimTan->name(), TAN},
{prim::kPrimSinh->name(), SINH},
{prim::kPrimCosh->name(), COSH},
{prim::kPrimAsinh->name(), ASINH},
{prim::kPrimAcosh->name(), ACOSH},
{prim::kPrimAtanh->name(), ATANH}};
void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
if (kernel_name == prim::kPrimSquare->name()) {
operate_type_ = SQUARE;
} else if (kernel_name == prim::kPrimOnesLike->name()) {
operate_type_ = ONESLIKE;
} else if (kernel_name == prim::kPrimZerosLike->name()) {
operate_type_ = ZEROSLIKE;
} else if (kernel_name == prim::kPrimNeg->name()) {
operate_type_ = NEG;
} else if (kernel_name == prim::kPrimLogicalNot->name()) {
operate_type_ = LOGICALNOT;
} else if (kernel_name == prim::kPrimSign->name()) {
operate_type_ = SIGN;
} else if (kernel_name == prim::kPrimFloor->name()) {
operate_type_ = FLOOR;
} else if (kernel_name == prim::kPrimReciprocal->name()) {
operate_type_ = RECIPROCAL;
} else if (kernel_name == prim::kPrimGeLU->name()) {
operate_type_ = GELU;
} else if (kernel_name == prim::kPrimAsin->name()) {
operate_type_ = ASIN;
} else if (kernel_name == prim::kPrimACos->name()) {
operate_type_ = ACOS;
} else if (kernel_name == prim::kPrimAtan->name()) {
operate_type_ = ATAN;
} else if (kernel_name == prim::kPrimSin->name()) {
operate_type_ = SIN;
} else if (kernel_name == prim::kPrimCos->name()) {
operate_type_ = COS;
} else if (kernel_name == prim::kPrimTan->name()) {
operate_type_ = TAN;
}
operate_type_ = kArithmeticOpTypeMap.at(kernel_name);
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
}
@ -259,7 +285,10 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
{GELU, Gelu<T>}, {SIN, Sin<T>},
{COS, Cos<T>}, {TAN, Tan<T>},
{ASIN, Asin<T>}, {ACOS, ACos<T>},
{ATAN, Atan<T>}};
{ATAN, Atan<T>}, {SINH, Sinh<T>},
{COSH, Cosh<T>}, {ASINH, Asinh<T>},
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>}};
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end));

@ -72,27 +72,25 @@ MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutp
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ArithmeticSelfCPUKernel);
} // namespace kernel
} // namespace mindspore

@ -102,6 +102,14 @@ enum OperateType {
SIN,
COS,
TAN,
SINH,
COSH,
ASINH,
ACOSH,
ATANH,
ASINHGRAD,
ACOSHGRAD,
ATAN2,
};
class CPUKernel : public kernel::KernelMod {

@ -153,6 +153,48 @@ void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, si
}
}
template <typename T>
void EltWiseGradCPUKernel::AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = sqrt(1 + input1[i] * input1[i]);
if (divisor == 0) {
if (dividend == 0) {
out[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
}
out[i] = dividend / divisor;
}
}
template <typename T>
void EltWiseGradCPUKernel::AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T dividend = input2[i];
T divisor = sqrt(input1[i] * input1[i] - 1);
if (divisor == 0) {
if (dividend == 0) {
out[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
}
out[i] = dividend / divisor;
}
}
void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
@ -176,6 +218,10 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
operate_type_ = ACOSGRAD;
} else if (kernel_name == "AtanGrad") {
operate_type_ = ATANGRAD;
} else if (kernel_name == "AsinhGrad") {
operate_type_ = ASINHGRAD;
} else if (kernel_name == "AcoshGrad") {
operate_type_ = ACOSHGRAD;
} else {
MS_LOG(EXCEPTION) << "Not support " << kernel_name;
}
@ -263,6 +309,10 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ATANGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AtanGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ASINHGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinhGrad<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ACOSHGRAD) {
threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AcoshGrad<T>, this, input1, input2, output, start, end));
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
}

@ -56,6 +56,10 @@ class EltWiseGradCPUKernel : public CPUKernel {
void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
template <typename T>
void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
std::vector<size_t> input_shape0_;
std::vector<size_t> input_shape1_;
std::vector<size_t> input_element_num0_;
@ -101,22 +105,21 @@ MS_REG_CPU_KERNEL(
AsinGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
AsinGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
ACosGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AtanGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
AtanGrad,
AsinhGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
MS_REG_CPU_KERNEL(
AtanGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
AcoshGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
EltWiseGradCPUKernel);
} // namespace kernel
} // namespace mindspore

@ -195,9 +195,15 @@ inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoft
inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
inline const PrimitivePtr kPrimLstm = std::make_shared<Primitive>("Lstm");
inline const PrimitivePtr kPrimTan = std::make_shared<Primitive>("Tan");
inline const PrimitivePtr kPrimAtan2 = std::make_shared<Primitive>("Atan2");
inline const PrimitivePtr kPrimAtan = std::make_shared<Primitive>("Atan");
inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
inline const PrimitivePtr kPrimSinh = std::make_shared<Primitive>("Sinh");
inline const PrimitivePtr kPrimCosh = std::make_shared<Primitive>("Cosh");
inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
inline const PrimitivePtr kPrimAsinh = std::make_shared<Primitive>("Asinh");
inline const PrimitivePtr kPrimAcosh = std::make_shared<Primitive>("Acosh");
inline const PrimitivePtr kPrimAtanh = std::make_shared<Primitive>("Atanh");
inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");

@ -2561,7 +2561,7 @@ class Acosh(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> acosh = ops.Acosh()
@ -2597,7 +2597,7 @@ class Cosh(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> cosh = ops.Cosh()
@ -2638,7 +2638,7 @@ class Asinh(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend`` ``GPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> asinh = ops.Asinh()
@ -2671,7 +2671,7 @@ class Sinh(PrimitiveWithInfer):
Tensor, has the same shape as `input_x`.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> sinh = ops.Sinh()
@ -3886,7 +3886,7 @@ class Atanh(PrimitiveWithInfer):
TypeError: If `input_x` is not a Tensor.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32)
@ -3931,7 +3931,7 @@ class Atan2(_MathBinaryOp):
TypeError: If `input_x` or `input_y` is not a Tensor.
Supported Platforms:
``Ascend``
``Ascend`` ``CPU``
Examples:
>>> input_x = Tensor(np.array([0, 1]), mindspore.float32)

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetAcoshGrad(nn.Cell):
def __init__(self):
super(NetAcoshGrad, self).__init__()
self.acoshGrad = G.AcoshGrad()
def construct(self, x, dy):
return self.acoshGrad(x, dy)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_acosh_grad():
x = np.array([5, 4, 3]).astype('float32')
dy = np.array([1, 0, -1]).astype('float32')
acosh_grad = NetAcoshGrad()
output = acosh_grad(Tensor(x), Tensor(dy))
print(output)
expect = dy / np.sqrt(x * x - 1)
assert np.allclose(output.asnumpy(), expect)

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetAcosh(nn.Cell):
def __init__(self):
super(NetAcosh, self).__init__()
self.acosh = P.Acosh()
def construct(self, x):
return self.acosh(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_acosh():
np_array = np.array([1, 2, 3, 4, 5]).astype('float32')
input_x = Tensor(np_array)
net = NetAcosh()
output = net(input_x)
print(output)
expect = np.arccosh(np_array)
assert np.allclose(output.asnumpy(), expect)

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops.operations import _grad_ops as G
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetAsinhGrad(nn.Cell):
def __init__(self):
super(NetAsinhGrad, self).__init__()
self.asinhGrad = G.AsinhGrad()
def construct(self, x, dy):
return self.asinhGrad(x, dy)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_asinh_grad():
x = np.array([-0.5, 0, 0.5]).astype('float32')
dy = np.array([1, 0, -1]).astype('float32')
asinh_grad = NetAsinhGrad()
output = asinh_grad(Tensor(x), Tensor(dy))
print(output)
expect = dy / np.sqrt(1 + x * x)
assert np.allclose(output.asnumpy(), expect)

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetAsinh(nn.Cell):
def __init__(self):
super(NetAsinh, self).__init__()
self.asinh = P.Asinh()
def construct(self, x):
return self.asinh(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_asinh():
np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
input_x = Tensor(np_array)
net = NetAsinh()
output = net(input_x)
print(output)
expect = np.arcsinh(np_array)
assert np.allclose(output.asnumpy(), expect)

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetAtan2(nn.Cell):
def __init__(self):
super(NetAtan2, self).__init__()
self.atan2 = P.Atan2()
def construct(self, x, y):
return self.atan2(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atan2():
np_array = np.array([1, 2, 3, 4, 5]).astype('float32')
input_x = Tensor(np_array)
net = NetAtan2()
output = net(input_x, input_x)
print(output)
expect = np.arctan2(np_array, np_array)
assert np.allclose(output.asnumpy(), expect)

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetAtanh(nn.Cell):
def __init__(self):
super(NetAtanh, self).__init__()
self.atanh = P.Atanh()
def construct(self, x):
return self.atanh(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_atanh():
np_array = np.array([-0.5, 0, 0.5]).astype('float32')
input_x = Tensor(np_array)
net = NetAtanh()
output = net(input_x)
print(output)
expect = np.arctanh(np_array)
assert np.allclose(output.asnumpy(), expect)

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetCosh(nn.Cell):
def __init__(self):
super(NetCosh, self).__init__()
self.cosh = P.Cosh()
def construct(self, x):
return self.cosh(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_cosh():
np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
input_x = Tensor(np_array)
net = NetCosh()
output = net(input_x)
print(output)
expect = np.cosh(np_array)
assert np.allclose(output.asnumpy(), expect)

@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
class NetSinh(nn.Cell):
def __init__(self):
super(NetSinh, self).__init__()
self.sinh = P.Sinh()
def construct(self, x):
return self.sinh(x)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_sinh():
np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32')
input_x = Tensor(np_array)
net = NetSinh()
output = net(input_x)
print(output)
expect = np.sinh(np_array)
assert np.allclose(output.asnumpy(), expect)
Loading…
Cancel
Save