|
|
|
@ -22,7 +22,6 @@
|
|
|
|
|
#include "src/runtime/allocator.h"
|
|
|
|
|
#include "nnacl/arithmetic_common.h"
|
|
|
|
|
#include "nnacl/fp32/arithmetic.h"
|
|
|
|
|
#include "schema/ops_generated.h"
|
|
|
|
|
|
|
|
|
|
typedef int (*ArithmeticRun)(float *input0, float *input1, float *output, int element_size);
|
|
|
|
|
typedef int (*ArithmeticOptRun)(float *input0, float *input1, float *output, int element_size,
|
|
|
|
@ -167,9 +166,9 @@ int DoArithmeticInferShape(const TensorPtrVector &in_tensors, const TensorPtrVec
|
|
|
|
|
|
|
|
|
|
int ChooseKernel(const int kernel_type, ArithmeticRun *arithmetic_run, ArithmeticParameter *params) {
|
|
|
|
|
if (kernel_type == KernelType::Mul) {
|
|
|
|
|
if (params->activation_type_ == mindspore::schema::ActivationType_RELU) {
|
|
|
|
|
if (params->activation_type_ == ActivationType::RELU) {
|
|
|
|
|
*arithmetic_run = ElementMulRelu;
|
|
|
|
|
} else if (params->activation_type_ == mindspore::schema::ActivationType_RELU6) {
|
|
|
|
|
} else if (params->activation_type_ == ActivationType::RELU6) {
|
|
|
|
|
*arithmetic_run = ElementMulRelu6;
|
|
|
|
|
} else {
|
|
|
|
|
*arithmetic_run = ElementMul;
|
|
|
|
@ -183,9 +182,9 @@ int ChooseKernel(const int kernel_type, ArithmeticRun *arithmetic_run, Arithmeti
|
|
|
|
|
|
|
|
|
|
int ChooseOptKernel(const int kernel_type, ArithmeticOptRun *arithmetic_opt_run, ArithmeticParameter *params) {
|
|
|
|
|
if (kernel_type == KernelType::Mul) {
|
|
|
|
|
if (params->activation_type_ == mindspore::schema::ActivationType_RELU) {
|
|
|
|
|
if (params->activation_type_ == ActivationType::RELU) {
|
|
|
|
|
*arithmetic_opt_run = ElementOptMulRelu;
|
|
|
|
|
} else if (params->activation_type_ == mindspore::schema::ActivationType_RELU6) {
|
|
|
|
|
} else if (params->activation_type_ == ActivationType::RELU6) {
|
|
|
|
|
*arithmetic_opt_run = ElementOptMulRelu6;
|
|
|
|
|
} else {
|
|
|
|
|
*arithmetic_opt_run = ElementOptMul;
|
|
|
|
|