!4113 Add fused_activation function for Sub, Add, Mul and Div ops

Merge pull request !4113 from wangminggui/master
commit f8e4ab86a2, merged by mindspore-ci-bot via Gitee

@@ -384,19 +384,19 @@ table Eltwise {
 }
 table Add {
-    activationType : ActivationType;
+    activationType: ActivationType = 0;
 }
 table Sub {
-    activationType : ActivationType;
+    activationType: ActivationType = 0;
 }
 table Mul {
-    activationType : ActivationType;
+    activationType: ActivationType = 0;
 }
 table Div {
-    activationType : ActivationType;
+    activationType: ActivationType = 0;
 }
 table AddGrad {
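
A note on the explicit `= 0` default: FlatBuffers scalar and enum fields already default to 0 when unset and are omitted from the buffer when equal to their default, so spelling the default out documents that a missing field reads back as 0. Assuming ActivationType_NO_ACTIVATION is value 0 in this schema, models exported before this change deserialize as "no fused activation" with no converter work. A minimal reader-side sketch using the accessors this PR relies on:

    // Sketch only; assumes ActivationType_NO_ACTIVATION == 0 in schema.fbs.
    auto add_prim = primitive->Value()->value_as_Add();
    auto act = add_prim->activationType();  // yields 0 when the field was never written
    bool fused = (act != schema::ActivationType_NO_ACTIVATION);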

@@ -510,6 +510,23 @@ OpParameter *PopulateArithmetic(const lite::Primitive *primitive) {
   arithmetic_param->op_parameter_.type_ = primitive->Type();
   arithmetic_param->broadcasting_ = ((lite::Arithmetic *)primitive)->Broadcasting();
   arithmetic_param->ndim_ = ((lite::Arithmetic *)primitive)->NDims();
+  switch (primitive->Type()) {
+    case schema::PrimitiveType_Add:
+      arithmetic_param->activation_type_ = primitive->Value()->value_as_Add()->activationType();
+      break;
+    case schema::PrimitiveType_Sub:
+      arithmetic_param->activation_type_ = primitive->Value()->value_as_Sub()->activationType();
+      break;
+    case schema::PrimitiveType_Mul:
+      arithmetic_param->activation_type_ = primitive->Value()->value_as_Mul()->activationType();
+      break;
+    case schema::PrimitiveType_Div:
+      arithmetic_param->activation_type_ = primitive->Value()->value_as_Div()->activationType();
+      break;
+    default:
+      arithmetic_param->activation_type_ = 0;
+      break;
+  }
   auto tmp_shape = ((lite::Arithmetic *)primitive)->InShape0();
   (void)memcpy(arithmetic_param->in_shape0_, static_cast<void *>(tmp_shape.data()), tmp_shape.size() * sizeof(int));
   tmp_shape = ((lite::Arithmetic *)primitive)->InShape1();
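
One hardening note on the new switch: the generated value_as_Add() family returns nullptr when the union actually holds a different table, so the unchecked dereference assumes the converter always pairs PrimitiveType_Add with an Add payload. A slightly defensive variant, sketched here rather than taken from the PR:

    case schema::PrimitiveType_Add: {
      // value_as_* returns nullptr if Value() holds a different table type.
      auto add = primitive->Value()->value_as_Add();
      arithmetic_param->activation_type_ = (add != nullptr) ? add->activationType() : 0;
      break;
    }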

@@ -56,31 +56,28 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
   auto input0_data = reinterpret_cast<float *>(inputs_[0]->Data());
   auto input1_data1 = reinterpret_cast<float *>(inputs_[1]->Data());
   auto output_data = reinterpret_cast<float *>(outputs_[0]->Data());
   auto element_num = outputs_[0]->ElementsNum();
-  if (arithmeticParameter_->broadcasting_) {
-    if (arithmetic_broadcast_run_ == nullptr) {
-      MS_LOG(ERROR) << "broadcasting_run function is nullptr!";
-      return RET_ERROR;
-    }
-    MS_ASSERT(thread_count_ != 0);
-    int stride = UP_DIV(element_num, thread_count_);
-    int count = MSMIN(stride, element_num - stride * task_id);
-    int error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id,
-                                     output_data + stride * task_id, count);
-    if (error_code != RET_OK) {
-      return RET_ERROR;
-    }
-  } else if (arithmetic_run_ != nullptr) {
-    int error_code = arithmetic_run_(input0_data, input1_data1, output_data, element_num);
-    if (error_code != RET_OK) {
-      return RET_ERROR;
-    }
-  } else {
-    MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
-    return RET_ERROR;
-  }
+  MS_ASSERT(thread_count_ != 0);
+  int stride = UP_DIV(element_num, thread_count_);
+  int count = MSMIN(stride, element_num - stride * task_id);
+  if (arithmetic_run_ == nullptr) {
+    MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
+    return RET_ERROR;
+  }
+  int error_code = RET_OK;
+  if (arithmeticParameter_->broadcasting_) {
+    error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id,
+                                 output_data + stride * task_id, count);
+  } else {
+    error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
+                                 output_data + stride * task_id, count);
+  }
+  if (error_code != RET_OK) {
+    return RET_ERROR;
+  }
   return RET_OK;
 }
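
Hoisting the stride/count computation out of the broadcast branch lets both paths share one nullptr check and process the same per-thread slice; the broadcast path now reads from the pre-tiled tile_data0_/tile_data1_ buffers through the same element-wise arithmetic_run_ rather than a separate broadcast kernel. A worked example of the split, assuming UP_DIV(a, b) is (a + b - 1) / b and MSMIN is a plain minimum:

    // element_num = 10, thread_count_ = 4  =>  stride = UP_DIV(10, 4) = 3
    // task_id 0: count = MSMIN(3, 10 - 0) = 3  -> elements [0, 3)
    // task_id 1: count = MSMIN(3, 10 - 3) = 3  -> elements [3, 6)
    // task_id 2: count = MSMIN(3, 10 - 6) = 3  -> elements [6, 9)
    // task_id 3: count = MSMIN(3, 10 - 9) = 1  -> elements [9, 10)

When element_num is small, trailing task ids can get count <= 0; the element-wise loops then simply do no work.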

@@ -50,22 +50,59 @@ class ArithmeticCPUKernel : public LiteKernel {
   ArithmeticCPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
                       const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx)
       : LiteKernel(parameter, inputs, outputs), thread_count_(ctx->thread_num_) {
+    arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
     switch (parameter->type_) {
       case PrimitiveType_Mul:
-        arithmetic_run_ = ElementMul;
-        arithmetic_broadcast_run_ = BroadcastMul;
+        switch (arithmeticParameter_->activation_type_) {
+          case schema::ActivationType_RELU:
+            arithmetic_run_ = ElementMulRelu;
+            break;
+          case schema::ActivationType_RELU6:
+            arithmetic_run_ = ElementMulRelu6;
+            break;
+          default:
+            arithmetic_run_ = ElementMul;
+            arithmetic_broadcast_run_ = BroadcastMul;
+            break;
+        }
         break;
       case PrimitiveType_Add:
-        arithmetic_run_ = ElementAdd;
-        arithmetic_broadcast_run_ = BroadcastAdd;
+        switch (arithmeticParameter_->activation_type_) {
+          case schema::ActivationType_RELU:
+            arithmetic_run_ = ElementAddRelu;
+            break;
+          case schema::ActivationType_RELU6:
+            arithmetic_run_ = ElementAddRelu6;
+            break;
+          default:
+            arithmetic_run_ = ElementAdd;
+            arithmetic_broadcast_run_ = BroadcastAdd;
+            break;
+        }
         break;
       case PrimitiveType_Sub:
-        arithmetic_run_ = ElementSub;
-        arithmetic_broadcast_run_ = BroadcastSub;
+        switch (arithmeticParameter_->activation_type_) {
+          case schema::ActivationType_RELU:
+            arithmetic_run_ = ElementSubRelu;
+            break;
+          case schema::ActivationType_RELU6:
+            arithmetic_run_ = ElementSubRelu6;
+            break;
+          default:
+            arithmetic_run_ = ElementSub;
+            arithmetic_broadcast_run_ = BroadcastSub;
+            break;
+        }
         break;
       case PrimitiveType_Div:
-        arithmetic_run_ = ElementDiv;
-        arithmetic_broadcast_run_ = BroadcastDiv;
+        switch (arithmeticParameter_->activation_type_) {
+          case schema::ActivationType_RELU:
+            arithmetic_run_ = ElementDivRelu;
+            break;
+          case schema::ActivationType_RELU6:
+            arithmetic_run_ = ElementDivRelu6;
+            break;
+          default:
+            arithmetic_run_ = ElementDiv;
+            arithmetic_broadcast_run_ = BroadcastDiv;
+            break;
+        }
         break;
       case PrimitiveType_LogicalAnd:
         arithmetic_run_ = ElementLogicalAnd;

@@ -125,7 +162,6 @@ class ArithmeticCPUKernel : public LiteKernel {
         arithmetic_broadcast_run_ = nullptr;
         break;
     }
-    arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
   }
   ~ArithmeticCPUKernel() override;
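
Two details are easy to miss here. First, arithmeticParameter_ is now assigned before the switch (and the old assignment after it is removed in the second hunk), which is required because the inner switches read arithmeticParameter_->activation_type_. Second, the fused RELU/RELU6 branches leave arithmetic_broadcast_run_ unset; this works together with the reworked DoArithmetic above, which runs the same element-wise arithmetic_run_ over pre-tiled buffers for broadcast shapes, so the fused variants presumably cover broadcasting too. As the op-by-activation matrix grows, a lookup table is one way to flatten the nested switches; a hypothetical sketch, not part of this PR:

    // Hypothetical refactor: map (op type, activation type) to a kernel.
    typedef int (*ElementFn)(float *, float *, float *, int);
    struct FusedKernelEntry {
      int op_type;
      int activation_type;
      ElementFn fn;
    };
    static const FusedKernelEntry kFusedKernels[] = {
      {PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulRelu},
      {PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6},
      {PrimitiveType_Add, schema::ActivationType_RELU, ElementAddRelu},
      {PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6},
      // ... Sub and Div entries follow the same pattern
    };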

@@ -27,6 +27,7 @@ struct ArithmeticParameter {
   OpParameter op_parameter_;
   bool broadcasting_;
   size_t ndim_;
+  int activation_type_;
   int in_shape0_[5];
   int in_shape1_[5];
   int out_shape_[5];

@@ -49,4 +50,3 @@ void TileDimensionsInt8(int8_t *data0, int8_t *data1, int8_t *tile_data0, int8_t
                         ArithmeticParameter *param);
-
 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_ARITHMETIC_COMMON_H_

@@ -47,7 +47,7 @@ inline int Relu6(const float *src, int length, float *dst) {
 inline int LRelu(const float *src, int length, float *dst, float alpha) {
   for (int i = 0; i < length; ++i) {
-    dst[i] = src[i] > (src[i] * alpha) ? src[i] : (src[i] * alpha);
+    dst[i] = src[i] > 0 ? src[i] : (src[i] * alpha);
   }
   return NNACL_OK;
 }
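
This is a straight bug fix: the old expression computes max(x, alpha * x), which coincides with LeakyReLU only for 0 <= alpha <= 1, while the corrected test implements f(x) = x for x > 0 and alpha * x otherwise, for any alpha. A quick check at alpha = 2:

    // x = 1.0f, alpha = 2.0f
    // old: dst = (1.0 > 2.0) ? 1.0 : 2.0;  -> 2.0  (wrong, LeakyReLU(1.0) = 1.0)
    // new: dst = (1.0 > 0)   ? 1.0 : 2.0;  -> 1.0  (correct)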

@@ -24,20 +24,28 @@
 #include "src/runtime/kernel/arm/nnacl/errorcode.h"
 int ElementMul(float *input0, float *input1, float *output, int element_size);
+int ElementMulRelu(float *input0, float *input1, float *output, int element_size);
+int ElementMulRelu6(float *input0, float *input1, float *output, int element_size);
 int BroadcastMul(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                  ArithmeticParameter *param);
 int ElementAdd(float *input0, float *input1, float *output, int element_size);
+int ElementAddRelu(float *input0, float *input1, float *output, int element_size);
+int ElementAddRelu6(float *input0, float *input1, float *output, int element_size);
 int BroadcastAdd(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                  ArithmeticParameter *param);
 int BroadcastAddInt8(int8_t *input0, int8_t *input1, int8_t *tile_input0, int8_t *tile_input1, int8_t *output,
                      int element_size, ArithmeticParameter *param);
 int ElementSub(float *input0, float *input1, float *output, int element_size);
+int ElementSubRelu(float *input0, float *input1, float *output, int element_size);
+int ElementSubRelu6(float *input0, float *input1, float *output, int element_size);
 int BroadcastSub(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                  ArithmeticParameter *param);
 int ElementDiv(float *input0, float *input1, float *output, int element_size);
+int ElementDivRelu(float *input0, float *input1, float *output, int element_size);
+int ElementDivRelu6(float *input0, float *input1, float *output, int element_size);
 int BroadcastDiv(float *input0, float *input1, float *tile_input0, float *tile_input1, float *output, int element_size,
                  ArithmeticParameter *param);
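
All the fused variants share the element-wise signature declared above. A minimal scalar sketch of what a kernel like ElementAddRelu computes; the in-tree implementations are likely NEON-vectorized, so treat this as an illustration of the fusion, not the actual code:

    // Fused add + ReLU in a single pass: the activation costs one compare per
    // element instead of a second traversal of the output buffer.
    int ElementAddRelu(float *input0, float *input1, float *output, int element_size) {
      for (int i = 0; i < element_size; ++i) {
        float sum = input0[i] + input1[i];
        output[i] = sum > 0.0f ? sum : 0.0f;
      }
      return NNACL_OK;
    }

The Relu6 counterparts would clamp the result to [0, 6] instead.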
