From 94e2ffcafec9c062d936543b03e381463c08c171 Mon Sep 17 00:00:00 2001 From: zengxianglong Date: Sat, 20 Feb 2021 18:06:31 +0800 Subject: [PATCH] fix conv_conv fusion bug --- .../optimizer/fusion/conv_conv_fusion.cc | 94 +++++++++++-------- 1 file changed, 54 insertions(+), 40 deletions(-) diff --git a/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc index 681e0b3795..a312d8cd9d 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_conv_fusion.cc @@ -156,6 +156,56 @@ const BaseRef ConvConvFusion::DefinePattern() const { return VectorRef({down_conv_var, up_conv_var, down_weight_var, down_bias_var}); } +void ReplaceParametersAndNodes(const FuncGraphPtr &func_graph, const CNodePtr &up_conv_cnode, + const CNodePtr &down_conv_cnode) { + auto down_weight_parameter = down_conv_cnode->input(kConvWeightIndex)->cast(); + auto up_weight_parameter = up_conv_cnode->input(kConvWeightIndex)->cast(); + auto new_weight_paramter = func_graph->add_parameter(); + if (GenNewConvWeight(down_weight_parameter, up_weight_parameter, new_weight_paramter) != RET_OK) { + MS_LOG(ERROR) << "GenNewConvWeight failed."; + return; + } + auto manager = func_graph->manager(); + manager->Replace(down_weight_parameter, new_weight_paramter); + // whether up conv node has bias + if (up_conv_cnode->inputs().size() == kConvWithBiasLen) { + ParameterPtr down_bias_parameter; + if (down_conv_cnode->inputs().size() == kConvWithBiasLen) { + down_bias_parameter = down_conv_cnode->input(kConvBiasIndex)->cast(); + } + auto up_bias_parameter = up_conv_cnode->input(kConvBiasIndex)->cast(); + auto new_bias_parameter = func_graph->add_parameter(); + if (GenNewConvBias(down_bias_parameter, down_weight_parameter, up_bias_parameter, new_bias_parameter) != RET_OK) { + MS_LOG(ERROR) << "GenNewConvBias failed."; + return; + } + if (down_conv_cnode->inputs().size() == kConvWithBiasLen) { + manager->Replace(down_bias_parameter, new_bias_parameter); + } else { + down_conv_cnode->add_input(new_bias_parameter); + } + } else { + MS_LOG(INFO) << "up conv node has no bias,no need replace bias."; + } + MS_LOG(INFO) << "fusion node success:" << down_conv_cnode->fullname_with_scope(); + // delete up conv node + manager->Replace(up_conv_cnode, up_conv_cnode->input(1)); + return; +} + +bool IsPrimitiveProper(const CNodePtr &up_conv_cnode, const CNodePtr &down_conv_cnode) { + auto down_primitive = GetValueNode>(down_conv_cnode->input(0)); + auto down_conv_primitive = utils::cast>(down_primitive); + auto up_primitive = GetValueNode>(up_conv_cnode->input(0)); + auto up_conv_primitive = utils::cast>(up_primitive); + return up_conv_primitive != nullptr && + up_conv_primitive->GetActivationType() == schema::ActivationType_NO_ACTIVATION && + up_conv_primitive->GetGroup() == 1 && down_conv_primitive->GetGroup() == 1 && + up_conv_primitive->GetKernelW() == down_conv_primitive->GetKernelW() && + up_conv_primitive->GetKernelH() == down_conv_primitive->GetKernelH() && + up_conv_primitive->GetPadMode() == down_conv_primitive->GetPadMode(); +} + // conv->conv1x1 fusion conv (w1x+b)w2+c = (w1*w2)*x+(w2*b+c) const AnfNodePtr ConvConvFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { @@ -198,54 +248,18 @@ const AnfNodePtr ConvConvFusion::Process(const FuncGraphPtr &func_graph, const A return nullptr; } if (cin0 * (cout1 - cout0) > cout0 * cout1) { - MS_LOG(INFO) << "conv_conv_fusion up conv and down conv node channel requirment not fit"; + MS_LOG(INFO) << "conv_conv_fusion up conv and down conv node channel requirement not fit"; return nullptr; } // multi output need skip if (IsMultiOutputTensors(func_graph, up_conv_cnode)) { return nullptr; } - auto down_primitive = GetValueNode>(down_conv_cnode->input(0)); - auto down_conv_primitive = utils::cast>(down_primitive); - auto up_primitive = GetValueNode>(up_conv_cnode->input(0)); - auto up_conv_primitive = utils::cast>(up_primitive); - // up conv node must no activation - if (up_conv_primitive == nullptr || up_conv_primitive->GetActivationType() != schema::ActivationType_NO_ACTIVATION) { - return nullptr; - } - if (up_conv_primitive->GetGroup() != 1 || down_conv_primitive->GetGroup() != 1) { + // up conv node must no activation, and attributes should be proper + if (!IsPrimitiveProper(up_conv_cnode, down_conv_cnode)) { return nullptr; } - auto new_weight_paramter = func_graph->add_parameter(); - if (GenNewConvWeight(down_weight_parameter, up_weight_parameter, new_weight_paramter) != RET_OK) { - MS_LOG(ERROR) << "GenNewConvWeight failed."; - return nullptr; - } - auto manager = func_graph->manager(); - manager->Replace(down_weight_parameter, new_weight_paramter); - // up conv node no bias - if (up_conv_cnode->inputs().size() == kConvWithBiasLen) { - ParameterPtr down_bias_parameter; - if (down_conv_cnode->inputs().size() == kConvWithBiasLen) { - down_bias_parameter = down_conv_cnode->input(kConvBiasIndex)->cast(); - } - auto up_bias_parameter = up_conv_cnode->input(kConvBiasIndex)->cast(); - auto new_bias_paramter = func_graph->add_parameter(); - if (GenNewConvBias(down_bias_parameter, down_weight_parameter, up_bias_parameter, new_bias_paramter) != RET_OK) { - MS_LOG(ERROR) << "GenNewConvBias failed."; - return nullptr; - } - if (down_conv_cnode->inputs().size() == kConvWithBiasLen) { - manager->Replace(down_bias_parameter, new_bias_paramter); - } else { - down_conv_cnode->add_input(new_bias_paramter); - } - } else { - MS_LOG(INFO) << "up conv node has no bias,no need replace bias."; - } - MS_LOG(INFO) << "fusion node success:" << down_conv_cnode->fullname_with_scope(); - // delete up conv node - manager->Replace(up_conv_cnode, up_conv_cnode->input(1)); + ReplaceParametersAndNodes(func_graph, up_conv_cnode, down_conv_cnode); return nullptr; } } // namespace mindspore::opt