|
|
|
@ -51,10 +51,36 @@ int ArithmeticCPUKernel::Init() {
|
|
|
|
|
|
|
|
|
|
int ArithmeticCPUKernel::ReSize() {
|
|
|
|
|
FreeTileData();
|
|
|
|
|
auto element_num = out_tensors_[0]->ElementsNum();
|
|
|
|
|
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
|
|
|
|
|
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
|
|
|
|
|
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
|
|
|
|
|
|
|
|
|
|
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
|
|
|
|
|
if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) {
|
|
|
|
|
switch (arithmeticParameter_->op_parameter_.type_) {
|
|
|
|
|
case PrimitiveType_Mul:
|
|
|
|
|
arithmeticParameter_->broadcasting_ = false;
|
|
|
|
|
arithmetic_opt_run_ = ElementOptMul;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Add:
|
|
|
|
|
arithmeticParameter_->broadcasting_ = false;
|
|
|
|
|
arithmetic_opt_run_ = ElementOptAdd;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Sub:
|
|
|
|
|
arithmeticParameter_->broadcasting_ = false;
|
|
|
|
|
arithmetic_opt_run_ = ElementOptSub;
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (arithmeticParameter_->broadcasting_) {
|
|
|
|
|
tile_data0_ = new float[arithmeticParameter_->out_elements_num_];
|
|
|
|
|
tile_data1_ = new float[arithmeticParameter_->out_elements_num_];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tile_data0_ = new float[element_num];
|
|
|
|
|
tile_data1_ = new float[element_num];
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -77,7 +103,17 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
|
|
|
|
|
if (arithmeticParameter_->broadcasting_) {
|
|
|
|
|
error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id,
|
|
|
|
|
output_data + stride * task_id, count);
|
|
|
|
|
|
|
|
|
|
} else if (arithmetic_opt_run_ != nullptr) {
|
|
|
|
|
if (arithmeticParameter_->in_elements_num0_ == 1) {
|
|
|
|
|
error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id,
|
|
|
|
|
count, arithmeticParameter_);
|
|
|
|
|
} else if (arithmeticParameter_->in_elements_num1_ == 1) {
|
|
|
|
|
error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1, output_data + stride * task_id,
|
|
|
|
|
count, arithmeticParameter_);
|
|
|
|
|
} else {
|
|
|
|
|
error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
|
|
|
|
|
output_data + stride * task_id, count, arithmeticParameter_);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
|
|
|
|
|
output_data + stride * task_id, count);
|
|
|
|
@ -104,6 +140,7 @@ int ArithmeticCPUKernel::Run() {
|
|
|
|
|
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (arithmeticParameter_->broadcasting_) {
|
|
|
|
|
auto input_data0 = reinterpret_cast<float *>(in_tensors_[0]->Data());
|
|
|
|
|
auto input_data1 = reinterpret_cast<float *>(in_tensors_[1]->Data());
|
|
|
|
|