fix conv_conv fusion bug

pull/12463/head
zengxianglong 4 years ago
parent d4abe53f34
commit 94e2ffcafe

@@ -156,6 +156,56 @@ const BaseRef ConvConvFusion::DefinePattern() const {
return VectorRef({down_conv_var, up_conv_var, down_weight_var, down_bias_var});
}
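// The pattern above presumably matches a down conv node whose first input is the
// up conv node, i.e. Conv(Conv(x, up_weight, up_bias), down_weight, down_bias),
// with the remaining vars binding the down conv's weight and bias inputs.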
void ReplaceParametersAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &up_conv_cnode,
const CNodePtr &down_conv_cnode) {
auto down_weight_parameter = down_conv_cnode->input(kConvWeightIndex)->cast<ParameterPtr>();
auto up_weight_parameter = up_conv_cnode->input(kConvWeightIndex)->cast<ParameterPtr>();
auto new_weight_parameter = func_graph->add_parameter();
if (GenNewConvWeight(down_weight_parameter, up_weight_parameter, new_weight_parameter) != RET_OK) {
MS_LOG(ERROR) << "GenNewConvWeight failed.";
return;
}
auto manager = func_graph->manager();
manager->Replace(down_weight_parameter, new_weight_parameter);
// check whether the up conv node has a bias
if (up_conv_cnode->inputs().size() == kConvWithBiasLen) {
ParameterPtr down_bias_parameter;
if (down_conv_cnode->inputs().size() == kConvWithBiasLen) {
down_bias_parameter = down_conv_cnode->input(kConvBiasIndex)->cast<ParameterPtr>();
}
auto up_bias_parameter = up_conv_cnode->input(kConvBiasIndex)->cast<ParameterPtr>();
auto new_bias_parameter = func_graph->add_parameter();
if (GenNewConvBias(down_bias_parameter, down_weight_parameter, up_bias_parameter, new_bias_parameter) != RET_OK) {
MS_LOG(ERROR) << "GenNewConvBias failed.";
return;
}
if (down_conv_cnode->inputs().size() == kConvWithBiasLen) {
manager->Replace(down_bias_parameter, new_bias_parameter);
} else {
down_conv_cnode->add_input(new_bias_parameter);
}
} else {
MS_LOG(INFO) << "up conv node has no bias,no need replace bias.";
}
MS_LOG(INFO) << "fusion node success:" << down_conv_cnode->fullname_with_scope();
// delete up conv node
manager->Replace(up_conv_cnode, up_conv_cnode->input(1));
return;
}
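// Illustrative sketch, not part of the pass: the arithmetic that GenNewConvWeight
// and GenNewConvBias presumably perform for a 1x1 down conv, assuming plain float
// buffers with up_weight laid out as [cout0][cin0*kh*kw] and down_weight as
// [cout1][cout0] (uses std::vector, so <vector> is assumed included):
//   new_weight = down_weight * up_weight
//   new_bias   = down_weight * up_bias + down_bias
void FuseConvParamsSketch(const std::vector<std::vector<float>> &up_weight,
                          const std::vector<std::vector<float>> &down_weight,
                          const std::vector<float> &up_bias,
                          const std::vector<float> &down_bias,
                          std::vector<std::vector<float>> *new_weight,
                          std::vector<float> *new_bias) {
  size_t cout1 = down_weight.size();
  size_t cout0 = up_weight.size();
  size_t inner = up_weight.empty() ? 0 : up_weight[0].size();  // cin0 * kh * kw
  new_weight->assign(cout1, std::vector<float>(inner, 0.0f));
  new_bias->assign(cout1, 0.0f);
  for (size_t k = 0; k < cout1; ++k) {
    for (size_t j = 0; j < cout0; ++j) {
      for (size_t i = 0; i < inner; ++i) {
        (*new_weight)[k][i] += down_weight[k][j] * up_weight[j][i];
      }
      (*new_bias)[k] += down_weight[k][j] * up_bias[j];  // the w2*b term
    }
    if (!down_bias.empty()) {
      (*new_bias)[k] += down_bias[k];  // the +c term (down conv may lack a bias)
    }
  }
}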
bool IsPrimitiveProper(const CNodePtr &up_conv_cnode, const CNodePtr &down_conv_cnode) {
auto down_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(down_conv_cnode->input(0));
auto down_conv_primitive = utils::cast<std::shared_ptr<lite::Conv2D>>(down_primitive);
auto up_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(up_conv_cnode->input(0));
auto up_conv_primitive = utils::cast<std::shared_ptr<lite::Conv2D>>(up_primitive);
return up_conv_primitive != nullptr && down_conv_primitive != nullptr &&
up_conv_primitive->GetActivationType() == schema::ActivationType_NO_ACTIVATION &&
up_conv_primitive->GetGroup() == 1 && down_conv_primitive->GetGroup() == 1 &&
up_conv_primitive->GetKernelW() == down_conv_primitive->GetKernelW() &&
up_conv_primitive->GetKernelH() == down_conv_primitive->GetKernelH() &&
up_conv_primitive->GetPadMode() == down_conv_primitive->GetPadMode();
}
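// Scalar sanity check of the fusion identity restated below (hypothetical
// one-channel values with 1x1 kernels, not taken from the commit):
//   w2*(w1*x + b) + c == (w2*w1)*x + (w2*b + c)
//   with w1 = 3, b = 0.5, w2 = -2, c = 1, x = 4:
//   left  = -2 * (3*4 + 0.5) + 1 = -2 * 12.5 + 1 = -24
//   right = (-2*3) * 4 + (-2*0.5 + 1) = -24 + 0 = -24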
// conv->conv1x1 fusion: w2*(w1*x + b) + c = (w2*w1)*x + (w2*b + c)
const AnfNodePtr ConvConvFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
@@ -198,54 +248,18 @@ const AnfNodePtr ConvConvFusion::Process(const FuncGraphPtr &func_graph, const A
return nullptr;
}
if (cin0 * (cout1 - cout0) > cout0 * cout1) {
MS_LOG(INFO) << "conv_conv_fusion up conv and down conv node channel requirment not fit";
MS_LOG(INFO) << "conv_conv_fusion up conv and down conv node channel requirement not fit";
return nullptr;
}
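// A plausible reading of the guard above (an assumption, not stated in the commit):
// with 1x1 kernels the two separate convs hold cin0*cout0 + cout0*cout1 weights,
// while the fused conv would hold cin0*cout1, and cin0*(cout1 - cout0) > cout0*cout1
// is exactly cin0*cout1 > cin0*cout0 + cout0*cout1, i.e. fusion would grow the model.
// Hypothetical numbers: cin0 = 64, cout0 = 16, cout1 = 256 gives 5120 weights in the
// two convs vs 16384 fused; the guard sees 15360 > 4096 and rejects the fusion.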
// skip when the up conv node has multiple outputs
if (IsMultiOutputTensors(func_graph, up_conv_cnode)) {
return nullptr;
}
auto down_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(down_conv_cnode->input(0));
auto down_conv_primitive = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(down_primitive);
auto up_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(up_conv_cnode->input(0));
auto up_conv_primitive = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(up_primitive);
// up conv node must have no activation
if (up_conv_primitive == nullptr || up_conv_primitive->GetActivationType() != schema::ActivationType_NO_ACTIVATION) {
return nullptr;
}
if (up_conv_primitive->GetGroup() != 1 || down_conv_primitive->GetGroup() != 1) {
// the up conv node must have no activation, and the conv attributes must be compatible
if (!IsPrimitiveProper(up_conv_cnode, down_conv_cnode)) {
return nullptr;
}
auto new_weight_parameter = func_graph->add_parameter();
if (GenNewConvWeight(down_weight_parameter, up_weight_parameter, new_weight_parameter) != RET_OK) {
MS_LOG(ERROR) << "GenNewConvWeight failed.";
return nullptr;
}
auto manager = func_graph->manager();
manager->Replace(down_weight_parameter, new_weight_parameter);
// whether the up conv node has a bias
if (up_conv_cnode->inputs().size() == kConvWithBiasLen) {
ParameterPtr down_bias_parameter;
if (down_conv_cnode->inputs().size() == kConvWithBiasLen) {
down_bias_parameter = down_conv_cnode->input(kConvBiasIndex)->cast<ParameterPtr>();
}
auto up_bias_parameter = up_conv_cnode->input(kConvBiasIndex)->cast<ParameterPtr>();
auto new_bias_parameter = func_graph->add_parameter();
if (GenNewConvBias(down_bias_parameter, down_weight_parameter, up_bias_parameter, new_bias_parameter) != RET_OK) {
MS_LOG(ERROR) << "GenNewConvBias failed.";
return nullptr;
}
if (down_conv_cnode->inputs().size() == kConvWithBiasLen) {
manager->Replace(down_bias_parameter, new_bias_parameter);
} else {
down_conv_cnode->add_input(new_bias_parameter);
}
} else {
MS_LOG(INFO) << "up conv node has no bias,no need replace bias.";
}
MS_LOG(INFO) << "fusion node success:" << down_conv_cnode->fullname_with_scope();
// delete up conv node
manager->Replace(up_conv_cnode, up_conv_cnode->input(1));
ReplaceParametersAndNodes(func_graph, up_conv_cnode, down_conv_cnode);
return nullptr;
}
} // namespace mindspore::opt
