fix arithmetic

pull/5371/head
sunsuodong 5 years ago
parent 52a7db8180
commit 0fce8a9786

@ -104,6 +104,7 @@ int ArithmeticFP16CPUKernel::Init() {
arithmetic_run_ = ElementSubFp16; arithmetic_run_ = ElementSubFp16;
break; break;
} }
break;
case PrimitiveType_Div: case PrimitiveType_Div:
switch (arithmeticParameter_->activation_type_) { switch (arithmeticParameter_->activation_type_) {
case schema::ActivationType_RELU: case schema::ActivationType_RELU:
@ -116,32 +117,46 @@ int ArithmeticFP16CPUKernel::Init() {
arithmetic_run_ = ElementDivFp16; arithmetic_run_ = ElementDivFp16;
break; break;
} }
break;
case PrimitiveType_FloorMod: case PrimitiveType_FloorMod:
arithmetic_run_ = ElementFloorModFp16; arithmetic_run_ = ElementFloorModFp16;
break;
case PrimitiveType_FloorDiv: case PrimitiveType_FloorDiv:
arithmetic_run_ = ElementFloorDivFp16; arithmetic_run_ = ElementFloorDivFp16;
break;
case PrimitiveType_LogicalAnd: case PrimitiveType_LogicalAnd:
arithmetic_run_ = ElementLogicalAndFp16; arithmetic_run_ = ElementLogicalAndFp16;
break;
case PrimitiveType_LogicalOr: case PrimitiveType_LogicalOr:
arithmetic_run_ = ElementLogicalOrFp16; arithmetic_run_ = ElementLogicalOrFp16;
break;
case PrimitiveType_SquaredDifference: case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifferenceFp16; arithmetic_run_ = ElementSquaredDifferenceFp16;
break;
case PrimitiveType_Maximum: case PrimitiveType_Maximum:
arithmetic_run_ = ElementMaximumFp16; arithmetic_run_ = ElementMaximumFp16;
break;
case PrimitiveType_Minimum: case PrimitiveType_Minimum:
arithmetic_run_ = ElementMinimumFp16; arithmetic_run_ = ElementMinimumFp16;
break;
case PrimitiveType_NotEqual: case PrimitiveType_NotEqual:
arithmetic_run_ = ElementNotEqualFp16; arithmetic_run_ = ElementNotEqualFp16;
break;
case PrimitiveType_Equal: case PrimitiveType_Equal:
arithmetic_run_ = ElementEqualFp16; arithmetic_run_ = ElementEqualFp16;
break;
case PrimitiveType_Less: case PrimitiveType_Less:
arithmetic_run_ = ElementLessFp16; arithmetic_run_ = ElementLessFp16;
break;
case PrimitiveType_LessEqual: case PrimitiveType_LessEqual:
arithmetic_run_ = ElementLessEqual; arithmetic_run_ = ElementLessEqual;
break;
case PrimitiveType_Greater: case PrimitiveType_Greater:
arithmetic_run_ = ElementGreaterFp16; arithmetic_run_ = ElementGreaterFp16;
break;
case PrimitiveType_GreaterEqual: case PrimitiveType_GreaterEqual:
arithmetic_run_ = ElementGreaterEqualFp16; arithmetic_run_ = ElementGreaterEqualFp16;
break;
default: default:
MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_; MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
arithmetic_run_ = nullptr; arithmetic_run_ = nullptr;
@ -219,42 +234,55 @@ int ArithmeticFP16CPUKernel::ReSize() {
case PrimitiveType_FloorMod: case PrimitiveType_FloorMod:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptFloorModFp16; arithmetic_opt_run_ = ElementOptFloorModFp16;
break;
case PrimitiveType_FloorDiv: case PrimitiveType_FloorDiv:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptFloorDivFp16; arithmetic_opt_run_ = ElementOptFloorDivFp16;
break;
case PrimitiveType_LogicalAnd: case PrimitiveType_LogicalAnd:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLogicalAndFp16; arithmetic_opt_run_ = ElementOptLogicalAndFp16;
break;
case PrimitiveType_LogicalOr: case PrimitiveType_LogicalOr:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLogicalOrFp16; arithmetic_opt_run_ = ElementOptLogicalOrFp16;
break;
case PrimitiveType_SquaredDifference: case PrimitiveType_SquaredDifference:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptSquaredDifferenceFp16; arithmetic_opt_run_ = ElementOptSquaredDifferenceFp16;
break;
case PrimitiveType_Maximum: case PrimitiveType_Maximum:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMaximumFp16; arithmetic_opt_run_ = ElementOptMaximumFp16;
break;
case PrimitiveType_Minimum: case PrimitiveType_Minimum:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMinimumFp16; arithmetic_opt_run_ = ElementOptMinimumFp16;
break;
case PrimitiveType_NotEqual: case PrimitiveType_NotEqual:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptNotEqualFp16; arithmetic_opt_run_ = ElementOptNotEqualFp16;
break;
case PrimitiveType_Equal: case PrimitiveType_Equal:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptEqualFp16; arithmetic_opt_run_ = ElementOptEqualFp16;
break;
case PrimitiveType_Less: case PrimitiveType_Less:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLessFp16; arithmetic_opt_run_ = ElementOptLessFp16;
break;
case PrimitiveType_LessEqual: case PrimitiveType_LessEqual:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptLessEqualFp16; arithmetic_opt_run_ = ElementOptLessEqualFp16;
break;
case PrimitiveType_Greater: case PrimitiveType_Greater:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptGreaterFp16; arithmetic_opt_run_ = ElementOptGreaterFp16;
break;
case PrimitiveType_GreaterEqual: case PrimitiveType_GreaterEqual:
arithmeticParameter_->broadcasting_ = false; arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptGreaterEqualFp16; arithmetic_opt_run_ = ElementOptGreaterEqualFp16;
break;
default: default:
break; break;
} }

Loading…
Cancel
Save