Refactor eltwise grad cpu ops

pull/14436/head
wuxuejian 4 years ago
parent d346a861bc
commit 1d5f77d075

@@ -18,11 +18,13 @@
 #include <memory>
 #include <vector>
 #include <limits>
+#include <string>
 #include "backend/kernel_compiler/cpu/cpu_kernel.h"
 #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
 
 namespace mindspore {
 namespace kernel {
+template <typename T>
 class EltWiseGradCPUKernel : public CPUKernel {
  public:
   EltWiseGradCPUKernel() = default;
@@ -32,95 +34,75 @@ class EltWiseGradCPUKernel : public CPUKernel {
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs) override;
-  template <typename T>
-  void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
 
  private:
-  template <typename T>
   void ReluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void ReLU6Grad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void AbsGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void SigmoidGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void SqrtGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void TanhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void GeluGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void AsinGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void ACosGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void AtanGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void AsinhGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  template <typename T>
   void AcoshGrad(const T *input1, const T *input2, T *out, size_t start, size_t end);
-  std::vector<size_t> input_shape0_;
-  std::vector<size_t> input_shape1_;
-  std::vector<size_t> input_element_num0_;
-  std::vector<size_t> input_element_num1_;
-  std::vector<size_t> output_shape_;
-  std::vector<size_t> output_element_num_;
-  OperateType operate_type_{RELUGRAD};
-  TypeId dtype_{kTypeUnknown};
+  std::string kernel_name_ = "";
 };
 
-MS_REG_CPU_KERNEL(
+MS_REG_CPU_KERNEL_T(
   ReluGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   ReLU6Grad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   AbsGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   SigmoidGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   SqrtGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   TanhGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(GeLUGrad,
-                  KernelAttr()
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddInputAttr(kNumberTypeFloat32)
-                    .AddOutputAttr(kNumberTypeFloat32),
-                  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(GeLUGrad,
+                    KernelAttr()
+                      .AddInputAttr(kNumberTypeFloat32)
+                      .AddInputAttr(kNumberTypeFloat32)
+                      .AddInputAttr(kNumberTypeFloat32)
+                      .AddOutputAttr(kNumberTypeFloat32),
+                    EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   AsinGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   ACosGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   AtanGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   AsinhGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
-MS_REG_CPU_KERNEL(
+  EltWiseGradCPUKernel, float);
+MS_REG_CPU_KERNEL_T(
   AcoshGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  EltWiseGradCPUKernel);
+  EltWiseGradCPUKernel, float);
 } // namespace kernel
 } // namespace mindspore
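The diff above replaces the per-method `template <typename T>` declarations, the templated `LaunchKernel`, and the runtime `dtype_` field with a class template that each `MS_REG_CPU_KERNEL_T(..., EltWiseGradCPUKernel, float)` line instantiates for a concrete dtype. Below is a minimal, self-contained sketch of that pattern, not code from this repository: the class name `ToyEltWiseGradKernel` and the ReLU-gradient body (pass the incoming gradient through where the forward output is positive) are illustrative assumptions.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Toy illustration of a kernel class templated on the element type T.
// A per-dtype registration (e.g. for float) simply instantiates the class,
// so no dtype branching is needed when the kernel is launched.
template <typename T>
class ToyEltWiseGradKernel {
 public:
  // ReLU gradient over the slice [start, end): keep dy where the forward
  // output y was positive, otherwise write zero.
  void ReluGrad(const T *dy, const T *y, T *out, std::size_t start, std::size_t end) const {
    for (std::size_t i = start; i < end; ++i) {
      out[i] = y[i] > static_cast<T>(0) ? dy[i] : static_cast<T>(0);
    }
  }
};

int main() {
  std::vector<float> dy{1.0f, 2.0f, 3.0f};
  std::vector<float> y{-1.0f, 0.5f, 2.0f};
  std::vector<float> out(3, 0.0f);
  ToyEltWiseGradKernel<float> kernel;  // one concrete instantiation per registered dtype
  kernel.ReluGrad(dy.data(), y.data(), out.data(), 0, out.size());
  for (float v : out) {
    std::cout << v << ' ';  // prints: 0 2 3
  }
  std::cout << '\n';
  return 0;
}
```

With the class itself parameterized on T, the compiler generates one specialized kernel per registration, which is presumably why the header no longer needs `dtype_` or a separate templated `LaunchKernel` entry point.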
