!7926 [ME][OptPass]fix bug when eliminate unused parameter in 'inline' pass

From: @chenfei52
Reviewed-by: 
Signed-off-by:
pull/7926/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 244b7034e8

@ -174,8 +174,9 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNod
MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString() MS_LOG(DEBUG) << "Decrease input:" << next_node->DebugString() << ",node:" << node->DebugString()
<< ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second; << ",num: " << node_input_num_[next_node] << ",decrease num:" << output_edge.second;
if (node_input_num_[next_node] < output_edge.second) { if (node_input_num_[next_node] < output_edge.second) {
MS_LOG(EXCEPTION) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node] MS_LOG(DEBUG) << "Input node:" << next_node->DebugString() << ",node_output_num" << node_input_num_[next_node]
<< ",depend edge:" << output_edge.second; << ",depend edge:" << output_edge.second;
continue;
} }
node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second; node_input_num_[next_node] = node_input_num_[next_node] - output_edge.second;
// allreduce first // allreduce first

@ -162,15 +162,16 @@ class InlinerBase : public AnfVisitor {
if (fg->parameters().size() != args.size()) { if (fg->parameters().size() != args.size()) {
return nullptr; return nullptr;
} }
auto is_unique_use = IsUniqueUse(fg, nullptr);
// Not to inline after block if it has switch call inside, to avoid switch expansion. // Not to inline after block if it has switch call inside, to avoid switch expansion.
if (fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) { if (!is_unique_use && fg->has_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK)) {
auto has_branch_call = GraphHasBranch(fg); auto has_branch_call = GraphHasBranch(fg);
if (has_branch_call) { if (has_branch_call) {
return TransformBranchCall(fg, node, args); return TransformBranchCall(fg, node, args);
} }
} }
if (use_move_ && IsUniqueUse(fg, nullptr)) { if (use_move_ && is_unique_use) {
auto mng = fg->manager(); auto mng = fg->manager();
MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(mng);
ReplaceParams(mng, args, fg); ReplaceParams(mng, args, fg);
@ -218,7 +219,11 @@ class InlinerBase : public AnfVisitor {
used_param_index.emplace_back(i); used_param_index.emplace_back(i);
} }
} }
if (used_param_index.size() != fg_params.size()) { // If all parameters are used by cnodes
if (used_param_index.size() == fg_params.size()) {
return nullptr;
}
if (transformed_branch_chache_.find(fg) == transformed_branch_chache_.end()) {
MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString(); MS_LOG(DEBUG) << "Parameter not used found for graph :" << fg->ToString();
// clone a new graph and ignore the not used parameters // clone a new graph and ignore the not used parameters
FuncGraphPtr new_fg = TransformableClone(fg); FuncGraphPtr new_fg = TransformableClone(fg);
@ -227,13 +232,18 @@ class InlinerBase : public AnfVisitor {
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params), std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(new_params),
[&new_fg_params](size_t i) { return new_fg_params[i]; }); [&new_fg_params](size_t i) { return new_fg_params[i]; });
new_fg->set_parameters(new_params); new_fg->set_parameters(new_params);
std::vector<AnfNodePtr> node_inputs; // New func graph must set FUNC_GRAPH_FLAG_AFTER_BLOCK flag otherwise the new graph will be inlined.
node_inputs.push_back(NewValueNode(new_fg)); new_fg->set_flag(FUNC_GRAPH_FLAG_AFTER_BLOCK, true);
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs), // Add new graph to the cache to improve perfomance when call HasBranchCall.
[&args](size_t i) { return args[i]; }); graph_branch_cache_[new_fg] = true;
return node->func_graph()->NewCNode(node_inputs); // If a graph be called at two or more locations, it should not be cloned once again, so add it to the cache.
transformed_branch_chache_[fg] = new_fg;
} }
return nullptr; std::vector<AnfNodePtr> node_inputs;
node_inputs.push_back(NewValueNode(transformed_branch_chache_[fg]));
std::transform(used_param_index.begin(), used_param_index.end(), std::back_inserter(node_inputs),
[&args](size_t i) { return args[i]; });
return node->func_graph()->NewCNode(node_inputs);
} }
// This is a try-best algorithm to find a graph which may generate branch call. // This is a try-best algorithm to find a graph which may generate branch call.
@ -290,6 +300,8 @@ class InlinerBase : public AnfVisitor {
bool use_move_; bool use_move_;
std::vector<std::pair<CriterionFuncType, bool>> criterions_; std::vector<std::pair<CriterionFuncType, bool>> criterions_;
std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_; std::unordered_map<FuncGraphPtr, bool> graph_branch_cache_;
// Key is the old func graph, and the value is the new func_graph
std::unordered_map<FuncGraphPtr, FuncGraphPtr> transformed_branch_chache_;
}; };
class Inliner : public InlinerBase { class Inliner : public InlinerBase {

Loading…
Cancel
Save