!6837 mul add fusion broaden condition

Merge pull request !6837 from zhaozhenlong/lite/issue/mul_add_fusion_cond_broaden
pull/6837/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 16b77da7dd

@ -31,6 +31,7 @@ namespace mindspore {
namespace lite {
#define MUL_ADD_MATCH_PATH_LEN 2
#define ADD_OP_BIAS_INDEX 1
#define MUL_OP_INPUT_INDEX 0
#define MUL_OP_BIAS_INDEX 1
#define MUL_OP_INPUT_NUM 2
#define ADD_OP_INPUT_NUM 2
@ -60,6 +61,23 @@ STATUS MulAddFusionPass::DefinePattern() {
return RET_OK;
}
bool ScaleInputShapeValid(const std::vector<int> &input_shape, const std::vector<int> &scale_shape,
const std::vector<int> &offset_shape) {
if (input_shape.size() < scale_shape.size() || scale_shape.size() == 0) {
return false;
}
size_t rank_diff = input_shape.size() - scale_shape.size();
for (size_t i = 0; i < scale_shape.size(); ++i) {
if (input_shape[i + rank_diff] != scale_shape[i]) {
return false;
}
}
if (scale_shape != offset_shape) {
return false;
}
return true;
}
STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
MS_ASSERT(graph != nullptr);
@ -79,7 +97,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
MS_ASSERT(mulNodeBiasTensor != nullptr);
if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode || mulNodeBiasTensor->dims.size() == 4) {
if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) {
// dont fusion, return
return RET_OK;
}
@ -96,7 +114,11 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
// dont fusion, return
return RET_OK;
}
// scale requires scale shape tail sub of input shape, scale shape same as bias shape
const auto &mulNodeInputTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_INPUT_INDEX));
if (!ScaleInputShapeValid(mulNodeInputTensor->dims, mulNodeBiasTensor->dims, addNodeBiasTensor->dims)) {
return RET_OK;
}
// convert mul and add to scale
auto status = AddNewScaleNode(graph, mulNode, addNode.get(), addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
if (RET_OK != status) {

Loading…
Cancel
Save