From 25a2b8cd5b2c4b06d4003d00e5d213db61ce99af Mon Sep 17 00:00:00 2001 From: Zhang Qinghua Date: Fri, 4 Dec 2020 10:19:17 +0800 Subject: [PATCH] Never inline middle after block for control flow. --- .../ccsrc/frontend/optimizer/irpass/inline.h | 69 ++++++++++++------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h index f893f5df93..3a4277c61f 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/inline.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/inline.h @@ -134,12 +134,24 @@ class InlinerBase : public AnfVisitor { std::vector args; (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args)); - // compare size to avoid the case that the function has default value after grad. + // Compare size to avoid the case that the function has default value after grad. // for which after renormalize, the function default value will be an input if (fg->parameters().size() != args.size()) { return nullptr; } + if (IsUniqueUse(nullptr, fg, nullptr)) { + // The other branch calling the last after block. + if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { + // Check if parameters' changed. + auto param_simplified_caller = SimplifyAfterParameter(fg, node, args); + if (param_simplified_caller != nullptr) { + return param_simplified_caller; + } + } + + // For the single used fg, including non-after and after not matched above, + // we move the whole fg nodes. if (use_move_) { auto mng = fg->manager(); MS_EXCEPTION_IF_NULL(mng); @@ -148,10 +160,20 @@ class InlinerBase : public AnfVisitor { mng->MoveAllCNodeDropGraph(fg, node->func_graph(), inputs[0]->scope()); return out_node; } - } else if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK) && GraphHasBranch(fg)) { - // Not to inline after block if it has switch call inside, to avoid switch expansion. - return TransformBranchCall(fg, node, args); + } else { + // We don't expand the middle multiple used after block, except the last one. + if (GraphHasBranch(fg)) { + return nullptr; + } + // Check if parameters' changed for the first met branch calling. + if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { + auto param_simplified_caller = SimplifyAfterParameter(fg, node, args); + if (param_simplified_caller != nullptr) { + return param_simplified_caller; + } + } } + // Or, just make a clone for not single used fg. return InlineClone(fg, node->func_graph(), args, inputs[0]->scope()); } @@ -183,37 +205,34 @@ class InlinerBase : public AnfVisitor { // For after block which contains branch call, delete the parameters which is not used. // In most cases, it may be a `Module` or other constant input. - AnfNodePtr TransformBranchCall(const FuncGraphPtr &fg, const AnfNodePtr &node, const std::vector &args) { + AnfNodePtr SimplifyAfterParameter(const FuncGraphPtr &fg, const AnfNodePtr &node, + const std::vector &args) { auto &fg_params = fg->parameters(); std::vector used_param_index; auto mng = fg->manager(); + bool should_simplify = false; for (size_t i = 0; i < fg_params.size(); i++) { if (mng->node_users()[fg_params[i]].size() != 0) { used_param_index.emplace_back(i); + } else { + MS_LOG(DEBUG) << "Not used parameter " << fg_params[i]->DebugString() << " for calling " << fg->ToString(); + should_simplify = true; } } - // If all parameters are used by cnodes - if (used_param_index.size() == fg_params.size()) { + if (!should_simplify) { return nullptr; } - if (transformed_branch_chache_.find(fg) == transformed_branch_chache_.end()) { - MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString(); - // clone a new graph and ignore the not used parameters - FuncGraphPtr new_fg = TransformableClone(fg); - auto &new_fg_params = new_fg->parameters(); - std::vector new_params; - std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params), - [&new_fg_params](size_t i) { return new_fg_params[i]; }); - new_fg->set_parameters(new_params); - // New func graph must set FUNC_GRAPH_FLAG_AFTER_BLOCK flag otherwise the new graph will be inlined. - new_fg->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true); - // Add new graph to the cache to improve perfomance when call HasBranchCall. - graph_branch_cache_[new_fg] = true; - // If a graph be called at two or more locations, it should not be cloned once again, so add it to the cache. - transformed_branch_chache_[fg] = new_fg; - } + MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString(); + // Clone a new graph and ignore the not used parameters + auto new_fg = TransformableClone(fg); + auto &new_fg_params = new_fg->parameters(); + std::vector new_params; + std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params), + [&new_fg_params](size_t i) { return new_fg_params[i]; }); + new_fg->set_parameters(new_params); + std::vector node_inputs; - node_inputs.push_back(NewValueNode(transformed_branch_chache_[fg])); + node_inputs.push_back(NewValueNode(new_fg)); std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs), [&args](size_t i) { return args[i]; }); return node->func_graph()->NewCNode(node_inputs); @@ -273,8 +292,6 @@ class InlinerBase : public AnfVisitor { bool use_move_; std::vector> criterions_; std::unordered_map graph_branch_cache_; - // Key is the old func graph, and the value is the new func_graph - std::unordered_map transformed_branch_chache_; }; bool IsUniqueUse(InlinerBase *, const FuncGraphPtr &fg, const AnfNodePtr &) {