From ab17c49eba1ba3ee2300484ec746aa04261be4ea Mon Sep 17 00:00:00 2001
From: wangrao
Date: Mon, 22 Feb 2021 10:24:18 +0800
Subject: [PATCH] add sinh, cosh, asinh, acosh, atanh, atan2, asinhgrad, acoshgrad for cpu

---
 .../cpu/arithmetic_cpu_kernel.cc              | 70 +++++++-------
 .../cpu/arithmetic_cpu_kernel.h               |  6 ++
 .../cpu/arithmetic_self_cpu_kernel.cc         | 93 ++++++++++++-------
 .../cpu/arithmetic_self_cpu_kernel.h          | 20 ++--
 .../backend/kernel_compiler/cpu/cpu_kernel.h  |  8 ++
 .../cpu/eltwise_grad_cpu_kernel.cc            | 50 ++++++++++
 .../cpu/eltwise_grad_cpu_kernel.h             | 15 +--
 mindspore/core/base/core_ops.h                |  6 ++
 mindspore/ops/operations/math_ops.py          | 12 +--
 tests/st/ops/cpu/test_acosh_grad_op.py        | 46 +++++++++
 tests/st/ops/cpu/test_acosh_op.py             | 46 +++++++++
 tests/st/ops/cpu/test_asinh_grad_op.py        | 46 +++++++++
 tests/st/ops/cpu/test_asinh_op.py             | 46 +++++++++
 tests/st/ops/cpu/test_atan2_op.py             | 46 +++++++++
 tests/st/ops/cpu/test_atanh_op.py             | 46 +++++++++
 tests/st/ops/cpu/test_cosh_op.py              | 46 +++++++++
 tests/st/ops/cpu/test_sinh_op.py              | 46 +++++++++
 17 files changed, 557 insertions(+), 91 deletions(-)
 create mode 100644 tests/st/ops/cpu/test_acosh_grad_op.py
 create mode 100644 tests/st/ops/cpu/test_acosh_op.py
 create mode 100644 tests/st/ops/cpu/test_asinh_grad_op.py
 create mode 100644 tests/st/ops/cpu/test_asinh_op.py
 create mode 100644 tests/st/ops/cpu/test_atan2_op.py
 create mode 100644 tests/st/ops/cpu/test_atanh_op.py
 create mode 100644 tests/st/ops/cpu/test_cosh_op.py
 create mode 100644 tests/st/ops/cpu/test_sinh_op.py

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
index a531095b33..82920c84ee 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
@@ -16,6 +16,7 @@
 #include <cmath>
 #include <string>
 #include <thread>
+#include <map>
 #include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"
 #include "runtime/device/cpu/cpu_device_address.h"
 
@@ -235,45 +236,40 @@ void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out,
   }
 }
 
+template <typename T>
+void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    std::vector<size_t> idx;
+    GenIndex(i, &idx);
+    out[i] = atan2(input1[idx[0]], input2[idx[1]]);
+  }
+}
+
+static const std::map<std::string, OperateType> kArithmeticBinOpTypeMap = {
+  {prim::kPrimGreater->name(), GREATER},
+  {prim::kPrimAdd->name(), ADD},
+  {prim::kPrimGreaterEqual->name(), GREATEREQUAL},
+  {prim::kPrimSub->name(), SUB},
+  {prim::kPrimLogicalAnd->name(), LOGICALAND},
+  {prim::kPrimMul->name(), MUL},
+  {prim::kPrimLessEqual->name(), LESSEQUAL},
+  {prim::kPrimDiv->name(), DIV},
+  {prim::kPrimLogicalOr->name(), LOGICALOR},
+  {prim::kPrimMod->name(), MOD},
+  {prim::kPrimAssignAdd->name(), ASSIGNADD},
+  {prim::kPrimPow->name(), POW},
+  {prim::kPrimFloorDiv->name(), FLOORDIV},
+  {prim::kPrimLess->name(), LESS},
+  {prim::kPrimNotEqual->name(), NOTEQUAL},
+  {prim::kPrimAtan2->name(), ATAN2},
+  {prim::kPrimRealDiv->name(), REALDIV},
+  {prim::kPrimEqual->name(), EQUAL},
+  {prim::kPrimSquaredDifference->name(), SQUAREDDIFFERENCE}};
+
 void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-  if (kernel_name == prim::kPrimAdd->name()) {
-    operate_type_ = ADD;
-  } else if (kernel_name == prim::kPrimSub->name()) {
-    operate_type_ = SUB;
-  } else if (kernel_name == prim::kPrimMul->name()) {
-    operate_type_ = MUL;
-  } else if (kernel_name == prim::kPrimRealDiv->name()) {
-    operate_type_ = REALDIV;
-  } else if (kernel_name == prim::kPrimDiv->name()) {
-    operate_type_ = DIV;
-  } else if (kernel_name == prim::kPrimFloorDiv->name()) {
-    operate_type_ = FLOORDIV;
-  } else if (kernel_name == prim::kPrimMod->name()) {
-    operate_type_ = MOD;
-  } else if (kernel_name == prim::kPrimPow->name()) {
-    operate_type_ = POW;
-  } else if (kernel_name == prim::kPrimLess->name()) {
-    operate_type_ = LESS;
-  } else if (kernel_name == prim::kPrimEqual->name()) {
-    operate_type_ = EQUAL;
-  } else if (kernel_name == prim::kPrimNotEqual->name()) {
-    operate_type_ = NOTEQUAL;
-  } else if (kernel_name == prim::kPrimGreater->name()) {
-    operate_type_ = GREATER;
-  } else if (kernel_name == prim::kPrimGreaterEqual->name()) {
-    operate_type_ = GREATEREQUAL;
-  } else if (kernel_name == prim::kPrimLessEqual->name()) {
-    operate_type_ = LESSEQUAL;
-  } else if (kernel_name == prim::kPrimLogicalAnd->name()) {
-    operate_type_ = LOGICALAND;
-  } else if (kernel_name == prim::kPrimLogicalOr->name()) {
-    operate_type_ = LOGICALOR;
-  } else if (kernel_name == prim::kPrimAssignAdd->name()) {
-    operate_type_ = ASSIGNADD;
-  } else if (kernel_name == prim::kPrimSquaredDifference->name()) {
-    operate_type_ = SQUAREDDIFFERENCE;
+  if (kArithmeticBinOpTypeMap.find(kernel_name) != kArithmeticBinOpTypeMap.end()) {
+    operate_type_ = kArithmeticBinOpTypeMap.at(kernel_name);
   } else {
     MS_LOG(EXCEPTION) << "Not support " << kernel_name;
   }
@@ -448,6 +444,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
     threads.emplace_back(std::thread(&ArithmeticCPUKernel::Pow<T>, this, input1, input2, output, start, end));
   } else if (operate_type_ == ASSIGNADD) {
     threads.emplace_back(std::thread(&ArithmeticCPUKernel::AssignAdd<T>, this, input1, input2, output, start, end));
+  } else if (operate_type_ == ATAN2) {
+    threads.emplace_back(std::thread(&ArithmeticCPUKernel::Atan2<T>, this, input1, input2, output, start, end));
   } else if (operate_type_ == SQUAREDDIFFERENCE) {
     threads.emplace_back(
       std::thread(&ArithmeticCPUKernel::SquaredDifference<T>, this, input1, input2, output, start, end));
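Note: as a quick end-to-end check of the new Atan2 CPU kernel, the sketch below mirrors the style of the st tests added at the bottom of this patch. It is illustration only, not part of the patch; it assumes the float32-only Atan2 registration added in arithmetic_cpu_kernel.h below, and that the GenIndex-based indexing above gives NumPy-style broadcasting between the two inputs.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

    class NetAtan2(nn.Cell):
        def __init__(self):
            super(NetAtan2, self).__init__()
            self.atan2 = P.Atan2()

        def construct(self, x, y):
            return self.atan2(x, y)

    # Broadcast case: (2, 3) against (1, 3).
    x = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
    y = np.array([[1, 2, 3]]).astype(np.float32)
    output = NetAtan2()(Tensor(x), Tensor(y))
    # Cross-check against NumPy's reference implementation.
    assert np.allclose(output.asnumpy(), np.arctan2(x, y))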
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
index 2fc2923dfa..cc2ab1a4a3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
@@ -58,6 +58,8 @@ class ArithmeticCPUKernel : public CPUKernel {
   template <typename T>
   void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end);
   template <typename T>
+  void Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
   void Less(const T *input1, const T *input2, bool *out, size_t start, size_t end);
   template <typename T>
   void Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end);
@@ -279,6 +281,10 @@ MS_REG_CPU_KERNEL(
 MS_REG_CPU_KERNEL(
   LogicalOr, KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
   ArithmeticCPUKernel);
+MS_REG_CPU_KERNEL(
+  Atan2,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ArithmeticCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore
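Note: the Atan2 registration above covers float32 only, and this patch also drops the Int32 variants of the unary trig registrations below, so integer data has to be cast before it reaches these kernels. A two-line illustration (not part of the patch):

    x = Tensor(np.array([0, 1]).astype(np.float32))  # ok: matches the float32 registration
    # Tensor(np.array([0, 1]))  # int32 would not match any CPU Atan2 registration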
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
index a779ce551c..04bfd4997e 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
@@ -136,42 +136,68 @@ void Tan(const T *in, T *out, size_t start, size_t end) {
     out[i] = tan(in[i]);
   }
 }
+
+template <typename T>
+void Sinh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = sinh(in[i]);
+  }
+}
+
+template <typename T>
+void Cosh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = cosh(in[i]);
+  }
+}
+
+template <typename T>
+void Asinh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = asinh(in[i]);
+  }
+}
+
+template <typename T>
+void Acosh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = acosh(in[i]);
+  }
+}
+
+template <typename T>
+void Atanh(const T *in, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    out[i] = atanh(in[i]);
+  }
+}
 }  // namespace
 
+static const std::map<std::string, OperateType> kArithmeticOpTypeMap = {{prim::kPrimNeg->name(), NEG},
+                                                                        {prim::kPrimSquare->name(), SQUARE},
+                                                                        {prim::kPrimOnesLike->name(), ONESLIKE},
+                                                                        {prim::kPrimZerosLike->name(), ZEROSLIKE},
+                                                                        {prim::kPrimLogicalNot->name(), LOGICALNOT},
+                                                                        {prim::kPrimSign->name(), SIGN},
+                                                                        {prim::kPrimFloor->name(), FLOOR},
+                                                                        {prim::kPrimReciprocal->name(), RECIPROCAL},
+                                                                        {prim::kPrimGeLU->name(), GELU},
+                                                                        {prim::kPrimAsin->name(), ASIN},
+                                                                        {prim::kPrimACos->name(), ACOS},
+                                                                        {prim::kPrimAtan->name(), ATAN},
+                                                                        {prim::kPrimSin->name(), SIN},
+                                                                        {prim::kPrimCos->name(), COS},
+                                                                        {prim::kPrimTan->name(), TAN},
+                                                                        {prim::kPrimSinh->name(), SINH},
+                                                                        {prim::kPrimCosh->name(), COSH},
+                                                                        {prim::kPrimAsinh->name(), ASINH},
+                                                                        {prim::kPrimAcosh->name(), ACOSH},
+                                                                        {prim::kPrimAtanh->name(), ATANH}};
+
 void ArithmeticSelfCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-  if (kernel_name == prim::kPrimSquare->name()) {
-    operate_type_ = SQUARE;
-  } else if (kernel_name == prim::kPrimOnesLike->name()) {
-    operate_type_ = ONESLIKE;
-  } else if (kernel_name == prim::kPrimZerosLike->name()) {
-    operate_type_ = ZEROSLIKE;
-  } else if (kernel_name == prim::kPrimNeg->name()) {
-    operate_type_ = NEG;
-  } else if (kernel_name == prim::kPrimLogicalNot->name()) {
-    operate_type_ = LOGICALNOT;
-  } else if (kernel_name == prim::kPrimSign->name()) {
-    operate_type_ = SIGN;
-  } else if (kernel_name == prim::kPrimFloor->name()) {
-    operate_type_ = FLOOR;
-  } else if (kernel_name == prim::kPrimReciprocal->name()) {
-    operate_type_ = RECIPROCAL;
-  } else if (kernel_name == prim::kPrimGeLU->name()) {
-    operate_type_ = GELU;
-  } else if (kernel_name == prim::kPrimAsin->name()) {
-    operate_type_ = ASIN;
-  } else if (kernel_name == prim::kPrimACos->name()) {
-    operate_type_ = ACOS;
-  } else if (kernel_name == prim::kPrimAtan->name()) {
-    operate_type_ = ATAN;
-  } else if (kernel_name == prim::kPrimSin->name()) {
-    operate_type_ = SIN;
-  } else if (kernel_name == prim::kPrimCos->name()) {
-    operate_type_ = COS;
-  } else if (kernel_name == prim::kPrimTan->name()) {
-    operate_type_ = TAN;
-  }
+  operate_type_ = kArithmeticOpTypeMap.at(kernel_name);
   dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
   target_dtype_ = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
 }
@@ -259,7 +285,10 @@ void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
     {GELU, Gelu<T>},       {SIN, Sin<T>},
     {COS, Cos<T>},         {TAN, Tan<T>},
     {ASIN, Asin<T>},       {ACOS, ACos<T>},
-    {ATAN, Atan<T>}};
+    {ATAN, Atan<T>},       {SINH, Sinh<T>},
+    {COSH, Cosh<T>},       {ASINH, Asinh<T>},
+    {ACOSH, Acosh<T>},     {ATANH, Atanh<T>}};
+
   while (start < lens) {
     size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
     threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end));
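Note: each of the five new unary kernels above delegates to the corresponding libm function through the kArithmeticOpFuncMap dispatch shown in the last hunk. A compact way to sanity-check all of them at once against their NumPy references (illustration only, not part of the patch; mind the domain restrictions for acosh and atanh):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

    class UnaryNet(nn.Cell):
        def __init__(self, op):
            super(UnaryNet, self).__init__()
            self.op = op

        def construct(self, x):
            return self.op(x)

    cases = [
        (P.Sinh(), np.sinh, np.array([-1, 0, 1])),
        (P.Cosh(), np.cosh, np.array([-1, 0, 1])),
        (P.Asinh(), np.arcsinh, np.array([-1, 0, 1])),
        (P.Acosh(), np.arccosh, np.array([1, 2, 5])),       # domain: x >= 1
        (P.Atanh(), np.arctanh, np.array([-0.5, 0, 0.5])),  # domain: |x| < 1
    ]
    for op, ref, data in cases:
        data = data.astype(np.float32)
        out = UnaryNet(op)(Tensor(data))
        assert np.allclose(out.asnumpy(), ref(data))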
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
index 5922294976..0a68b3722b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
@@ -72,27 +72,25 @@ MS_REG_CPU_KERNEL(LogicalNot, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutp
                   ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Asin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(ACos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Atan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Sin, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Cos, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                  ArithmeticSelfCPUKernel);
 MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
-MS_REG_CPU_KERNEL(Tan, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+MS_REG_CPU_KERNEL(Sinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Cosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Asinh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Acosh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                  ArithmeticSelfCPUKernel);
+MS_REG_CPU_KERNEL(Atanh, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                   ArithmeticSelfCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
index a24ce767a7..8bd3f04e26 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
@@ -102,6 +102,14 @@ enum OperateType {
   SIN,
   COS,
   TAN,
+  SINH,
+  COSH,
+  ASINH,
+  ACOSH,
+  ATANH,
+  ASINHGRAD,
+  ACOSHGRAD,
+  ATAN2,
 };
 
 class CPUKernel : public kernel::KernelMod {

diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
index f7b6341963..f38180dece 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.cc
@@ -153,6 +153,48 @@ void EltWiseGradCPUKernel::AtanGrad(const T *input1, const T *input2, T *out, si
   }
 }
 
+template <typename T>
+void EltWiseGradCPUKernel::AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    T dividend = input2[i];
+    T divisor = sqrt(1 + input1[i] * input1[i]);
+    if (divisor == 0) {
+      if (dividend == 0) {
+        out[i] = std::numeric_limits<T>::quiet_NaN();
+        continue;
+      }
+      if (std::numeric_limits<T>::has_infinity) {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
+      } else {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
+      }
+      continue;
+    }
+    out[i] = dividend / divisor;
+  }
+}
+
+template <typename T>
+void EltWiseGradCPUKernel::AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end) {
+  for (size_t i = start; i < end; i++) {
+    T dividend = input2[i];
+    T divisor = sqrt(input1[i] * input1[i] - 1);
+    if (divisor == 0) {
+      if (dividend == 0) {
+        out[i] = std::numeric_limits<T>::quiet_NaN();
+        continue;
+      }
+      if (std::numeric_limits<T>::has_infinity) {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
+      } else {
+        out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
+      }
+      continue;
+    }
+    out[i] = dividend / divisor;
+  }
+}
+
 void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
@@ -176,6 +218,10 @@ void EltWiseGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     operate_type_ = ACOSGRAD;
   } else if (kernel_name == "AtanGrad") {
     operate_type_ = ATANGRAD;
+  } else if (kernel_name == "AsinhGrad") {
+    operate_type_ = ASINHGRAD;
+  } else if (kernel_name == "AcoshGrad") {
+    operate_type_ = ACOSHGRAD;
   } else {
     MS_LOG(EXCEPTION) << "Not support " << kernel_name;
   }
@@ -263,6 +309,10 @@ void EltWiseGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
     threads.emplace_back(std::thread(&EltWiseGradCPUKernel::ACosGrad<T>, this, input1, input2, output, start, end));
   } else if (operate_type_ == ATANGRAD) {
     threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AtanGrad<T>, this, input1, input2, output, start, end));
+  } else if (operate_type_ == ASINHGRAD) {
+    threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AsinhGrad<T>, this, input1, input2, output, start, end));
+  } else if (operate_type_ == ACOSHGRAD) {
+    threads.emplace_back(std::thread(&EltWiseGradCPUKernel::AcoshGrad<T>, this, input1, input2, output, start, end));
   } else {
     MS_LOG(EXCEPTION) << "Not support " << operate_type_;
   }
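Note: AsinhGrad and AcoshGrad above compute dy / sqrt(1 + x*x) and dy / sqrt(x*x - 1), the derivatives of asinh and acosh evaluated at the first input, with an explicit guard for a zero divisor. Since sqrt(1 + x*x) >= 1 for any real x, the guard can only actually fire in AcoshGrad, at the domain boundary |x| = 1. A NumPy sketch of the same per-element semantics (illustration only, not part of the patch):

    import numpy as np

    def acosh_grad_ref(x, dy):
        # Mirrors the CPU AcoshGrad kernel's per-element behavior.
        out = np.empty_like(x)
        divisor = np.sqrt(x * x - 1)
        for i in range(x.size):
            if divisor[i] == 0:            # domain boundary |x| == 1
                if dy[i] == 0:
                    out[i] = np.nan        # 0 / 0 case
                else:
                    out[i] = np.inf if dy[i] > 0 else -np.inf
            else:
                out[i] = dy[i] / divisor[i]
        return out

    x = np.array([1.0, 2.0, 5.0], dtype=np.float32)
    dy = np.array([1.0, 1.0, -1.0], dtype=np.float32)
    print(acosh_grad_ref(x, dy))  # [inf, 0.577..., -0.204...]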
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
index d3211d28a2..b2ed04cf1b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/eltwise_grad_cpu_kernel.h
@@ -56,6 +56,10 @@ class EltWiseGradCPUKernel : public CPUKernel {
   void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
   template <typename T>
   void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
+  template <typename T>
+  void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
   std::vector<size_t> input_shape0_;
   std::vector<size_t> input_shape1_;
   std::vector<size_t> input_element_num0_;
@@ -101,22 +105,21 @@ MS_REG_CPU_KERNEL(
   AsinGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
-  AsinGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  EltWiseGradCPUKernel);
 MS_REG_CPU_KERNEL(
   ACosGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
 MS_REG_CPU_KERNEL(
-  ACosGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  AtanGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
 MS_REG_CPU_KERNEL(
-  AtanGrad,
+  AsinhGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
 MS_REG_CPU_KERNEL(
-  AtanGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+  AcoshGrad,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   EltWiseGradCPUKernel);
 }  // namespace kernel
 }  // namespace mindspore

diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index 6c063d349f..39c1bf92f4 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -195,9 +195,15 @@ inline const PrimitivePtr kPrimLogSoftmax = std::make_shared<Primitive>("LogSoftmax");
 inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared<Primitive>("LogSoftmaxGrad");
 inline const PrimitivePtr kPrimLstm = std::make_shared<Primitive>("Lstm");
 inline const PrimitivePtr kPrimTan = std::make_shared<Primitive>("Tan");
+inline const PrimitivePtr kPrimAtan2 = std::make_shared<Primitive>("Atan2");
 inline const PrimitivePtr kPrimAtan = std::make_shared<Primitive>("Atan");
 inline const PrimitivePtr kPrimAsin = std::make_shared<Primitive>("Asin");
+inline const PrimitivePtr kPrimSinh = std::make_shared<Primitive>("Sinh");
+inline const PrimitivePtr kPrimCosh = std::make_shared<Primitive>("Cosh");
 inline const PrimitivePtr kPrimTanh = std::make_shared<Primitive>("Tanh");
+inline const PrimitivePtr kPrimAsinh = std::make_shared<Primitive>("Asinh");
+inline const PrimitivePtr kPrimAcosh = std::make_shared<Primitive>("Acosh");
+inline const PrimitivePtr kPrimAtanh = std::make_shared<Primitive>("Atanh");
 inline const PrimitivePtr kPrimTanhGrad = std::make_shared<Primitive>("TanhGrad");
 inline const PrimitivePtr kPrimPooling = std::make_shared<Primitive>("Pooling");
 inline const PrimitivePtr kPrimPoolingGrad = std::make_shared<Primitive>("PoolingGrad");

diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index 42a6d8a910..86441d46e6 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -2561,7 +2561,7 @@ class Acosh(PrimitiveWithInfer):
         TypeError: If `input_x` is not a Tensor.
Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> acosh = ops.Acosh() @@ -2597,7 +2597,7 @@ class Cosh(PrimitiveWithInfer): TypeError: If `input_x` is not a Tensor. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` Examples: >>> cosh = ops.Cosh() @@ -2638,7 +2638,7 @@ class Asinh(PrimitiveWithInfer): TypeError: If `input_x` is not a Tensor. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> asinh = ops.Asinh() @@ -2671,7 +2671,7 @@ class Sinh(PrimitiveWithInfer): Tensor, has the same shape as `input_x`. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` Examples: >>> sinh = ops.Sinh() @@ -3886,7 +3886,7 @@ class Atanh(PrimitiveWithInfer): TypeError: If `input_x` is not a Tensor. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` Examples: >>> input_x = Tensor(np.array([1.047, 0.785]), mindspore.float32) @@ -3931,7 +3931,7 @@ class Atan2(_MathBinaryOp): TypeError: If `input_x` or `input_y` is not a Tensor. Supported Platforms: - ``Ascend`` + ``Ascend`` ``CPU`` Examples: >>> input_x = Tensor(np.array([0, 1]), mindspore.float32) diff --git a/tests/st/ops/cpu/test_acosh_grad_op.py b/tests/st/ops/cpu/test_acosh_grad_op.py new file mode 100644 index 0000000000..033cd8067e --- /dev/null +++ b/tests/st/ops/cpu/test_acosh_grad_op.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops.operations import _grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetAcoshGrad(nn.Cell): + def __init__(self): + super(NetAcoshGrad, self).__init__() + self.acoshGrad = G.AcoshGrad() + + def construct(self, x, dy): + return self.acoshGrad(x, dy) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_acosh_grad(): + x = np.array([5, 4, 3]).astype('float32') + dy = np.array([1, 0, -1]).astype('float32') + acosh_grad = NetAcoshGrad() + output = acosh_grad(Tensor(x), Tensor(dy)) + print(output) + expect = dy / np.sqrt(x * x - 1) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/cpu/test_acosh_op.py b/tests/st/ops/cpu/test_acosh_op.py new file mode 100644 index 0000000000..070e6a6ad9 --- /dev/null +++ b/tests/st/ops/cpu/test_acosh_op.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetAcosh(nn.Cell): + def __init__(self): + super(NetAcosh, self).__init__() + self.acosh = P.Acosh() + + def construct(self, x): + return self.acosh(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_acosh(): + np_array = np.array([1, 2, 3, 4, 5]).astype('float32') + input_x = Tensor(np_array) + net = NetAcosh() + output = net(input_x) + print(output) + expect = np.arccosh(np_array) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/cpu/test_asinh_grad_op.py b/tests/st/ops/cpu/test_asinh_grad_op.py new file mode 100644 index 0000000000..916d74b9ec --- /dev/null +++ b/tests/st/ops/cpu/test_asinh_grad_op.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops.operations import _grad_ops as G + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetAsinhGrad(nn.Cell): + def __init__(self): + super(NetAsinhGrad, self).__init__() + self.asinhGrad = G.AsinhGrad() + + def construct(self, x, dy): + return self.asinhGrad(x, dy) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_asinh_grad(): + x = np.array([-0.5, 0, 0.5]).astype('float32') + dy = np.array([1, 0, -1]).astype('float32') + asinh_grad = NetAsinhGrad() + output = asinh_grad(Tensor(x), Tensor(dy)) + print(output) + expect = dy / np.sqrt(1 + x * x) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/cpu/test_asinh_op.py b/tests/st/ops/cpu/test_asinh_op.py new file mode 100644 index 0000000000..aa856b2feb --- /dev/null +++ b/tests/st/ops/cpu/test_asinh_op.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetAsinh(nn.Cell): + def __init__(self): + super(NetAsinh, self).__init__() + self.asinh = P.Asinh() + + def construct(self, x): + return self.asinh(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_asinh(): + np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32') + input_x = Tensor(np_array) + net = NetAsinh() + output = net(input_x) + print(output) + expect = np.arcsinh(np_array) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/cpu/test_atan2_op.py b/tests/st/ops/cpu/test_atan2_op.py new file mode 100644 index 0000000000..3a379939db --- /dev/null +++ b/tests/st/ops/cpu/test_atan2_op.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetAtan2(nn.Cell): + def __init__(self): + super(NetAtan2, self).__init__() + self.atan2 = P.Atan2() + + def construct(self, x, y): + return self.atan2(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_atan2(): + np_array = np.array([1, 2, 3, 4, 5]).astype('float32') + input_x = Tensor(np_array) + net = NetAtan2() + output = net(input_x, input_x) + print(output) + expect = np.arctan2(np_array, np_array) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/cpu/test_atanh_op.py b/tests/st/ops/cpu/test_atanh_op.py new file mode 100644 index 0000000000..8bf3ed82e4 --- /dev/null +++ b/tests/st/ops/cpu/test_atanh_op.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetAtanh(nn.Cell): + def __init__(self): + super(NetAtanh, self).__init__() + self.atanh = P.Atanh() + + def construct(self, x): + return self.atanh(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_atanh(): + np_array = np.array([-0.5, 0, 0.5]).astype('float32') + input_x = Tensor(np_array) + net = NetAtanh() + output = net(input_x) + print(output) + expect = np.arctanh(np_array) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/cpu/test_cosh_op.py b/tests/st/ops/cpu/test_cosh_op.py new file mode 100644 index 0000000000..3d924c7339 --- /dev/null +++ b/tests/st/ops/cpu/test_cosh_op.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetCosh(nn.Cell): + def __init__(self): + super(NetCosh, self).__init__() + self.cosh = P.Cosh() + + def construct(self, x): + return self.cosh(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_cosh(): + np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32') + input_x = Tensor(np_array) + net = NetCosh() + output = net(input_x) + print(output) + expect = np.cosh(np_array) + assert np.allclose(output.asnumpy(), expect) diff --git a/tests/st/ops/cpu/test_sinh_op.py b/tests/st/ops/cpu/test_sinh_op.py new file mode 100644 index 0000000000..ec25d3bbf4 --- /dev/null +++ b/tests/st/ops/cpu/test_sinh_op.py @@ -0,0 +1,46 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ + +import numpy as np +import pytest + +import mindspore.nn as nn +from mindspore import Tensor +from mindspore import context +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") + + +class NetSinh(nn.Cell): + def __init__(self): + super(NetSinh, self).__init__() + self.sinh = P.Sinh() + + def construct(self, x): + return self.sinh(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_sinh(): + np_array = np.array([-1, -0.5, 0, 0.5, 1]).astype('float32') + input_x = Tensor(np_array) + net = NetSinh() + output = net(input_x) + print(output) + expect = np.sinh(np_array) + assert np.allclose(output.asnumpy(), expect)