diff --git a/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc
index 1419e50cb9..f3a406e515 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/fusion/mul_add_fusion_pass.cc
@@ -31,6 +31,7 @@ namespace mindspore {
 namespace lite {
 #define MUL_ADD_MATCH_PATH_LEN 2
 #define ADD_OP_BIAS_INDEX 1
+#define MUL_OP_INPUT_INDEX 0
 #define MUL_OP_BIAS_INDEX 1
 #define MUL_OP_INPUT_NUM 2
 #define ADD_OP_INPUT_NUM 2
@@ -60,6 +61,23 @@ STATUS MulAddFusionPass::DefinePattern() {
   return RET_OK;
 }
 
+bool ScaleInputShapeValid(const std::vector<int32_t> &input_shape, const std::vector<int32_t> &scale_shape,
+                          const std::vector<int32_t> &offset_shape) {
+  if (input_shape.size() < scale_shape.size() || scale_shape.size() == 0) {
+    return false;
+  }
+  size_t rank_diff = input_shape.size() - scale_shape.size();
+  for (size_t i = 0; i < scale_shape.size(); ++i) {
+    if (input_shape[i + rank_diff] != scale_shape[i]) {
+      return false;
+    }
+  }
+  if (scale_shape != offset_shape) {
+    return false;
+  }
+  return true;
+}
+
 STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
                                   std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
   MS_ASSERT(graph != nullptr);
@@ -79,7 +97,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
   MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
   const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
   MS_ASSERT(mulNodeBiasTensor != nullptr);
-  if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode || mulNodeBiasTensor->dims.size() == 4) {
+  if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) {
     // dont fusion, return
     return RET_OK;
   }
@@ -96,7 +114,11 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
     // dont fusion, return
     return RET_OK;
   }
-
+  // scale requires the scale shape to be a tail sub-shape of the input shape, and the offset shape to equal it
+  const auto &mulNodeInputTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_INPUT_INDEX));
+  if (!ScaleInputShapeValid(mulNodeInputTensor->dims, mulNodeBiasTensor->dims, addNodeBiasTensor->dims)) {
+    return RET_OK;
+  }
  // convert mul and add to scale
   auto status = AddNewScaleNode(graph, mulNode, addNode.get(), addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
   if (RET_OK != status) {
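
For reference, below is a minimal standalone C++ sketch (not part of the patch) of the tail-matching rule that the new ScaleInputShapeValid check enforces before Mul+Add is rewritten into a Scale op. The helper name, the main() driver, and the example shapes are illustrative only.

// Sketch of the rule: the scale shape must match the trailing dimensions of the
// input shape, and the offset (add bias) shape must equal the scale shape.
#include <cstdint>
#include <iostream>
#include <vector>

static bool ScaleInputShapeValidSketch(const std::vector<int32_t> &input_shape,
                                       const std::vector<int32_t> &scale_shape,
                                       const std::vector<int32_t> &offset_shape) {
  if (input_shape.size() < scale_shape.size() || scale_shape.empty()) {
    return false;
  }
  // Compare scale_shape against the tail of input_shape.
  size_t rank_diff = input_shape.size() - scale_shape.size();
  for (size_t i = 0; i < scale_shape.size(); ++i) {
    if (input_shape[i + rank_diff] != scale_shape[i]) {
      return false;
    }
  }
  return scale_shape == offset_shape;
}

int main() {
  // Illustrative NHWC input with a per-channel scale/offset.
  std::vector<int32_t> input = {1, 224, 224, 3};
  std::cout << ScaleInputShapeValidSketch(input, {3}, {3}) << "\n";            // 1: fusion allowed
  std::cout << ScaleInputShapeValidSketch(input, {224, 3}, {224, 3}) << "\n";  // 1: tail matches
  std::cout << ScaleInputShapeValidSketch(input, {224}, {224}) << "\n";        // 0: last dim is 3, not 224
  std::cout << ScaleInputShapeValidSketch(input, {3}, {1}) << "\n";            // 0: offset differs from scale
  return 0;
}

When the check fails, DoFusion simply returns RET_OK without rewriting the graph, so mismatched Mul/Add pairs are left as-is rather than fused into an invalid Scale node.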