diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 978e4dc5d3..d9f3afb06f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -65,33 +65,57 @@ AnalysisContextPtr BaseFuncGraphEvaluator::MakeContext(const AnalysisEnginePtr & return context; } -static std::vector FastShadowSort(const AnfNodePtr &ret_node) { - auto current_func_graph = ret_node->func_graph(); +// Return CNodes set that may contain duplicates from a DAG function graph. +// Should check the results' second item as ignored flag before use them, to avoid processing repeatedly. +static inline std::vector> SortCNodesContainDup(const AnfNodePtr &root_node) { + auto current_func_graph = root_node->func_graph(); MS_EXCEPTION_IF_NULL(current_func_graph); - std::vector sorted_nodes; - auto seen = NewSeenGeneration(); + std::vector> sorted_nodes; // Record {node, ignored_flag}. + std::unordered_map checked_cnodes; // Record {node, position_in_sorted_nodes} std::size_t index = 0; - sorted_nodes.emplace_back(ret_node); + sorted_nodes.emplace_back(std::pair(root_node, false)); while (index < sorted_nodes.size()) { - auto current = sorted_nodes[index]; - index++; + auto current = sorted_nodes[index].first; MS_EXCEPTION_IF_NULL(current); - if (current->isa()) { + auto ignored_flag = sorted_nodes[index].second; + if (!ignored_flag && current->isa()) { auto &inputs = current->cast()->inputs(); - for (auto it = inputs.begin(); it != inputs.end(); it++) { + for (auto it = inputs.crbegin(); it != inputs.crend(); it++) { AnfNodePtr input = *it; - if (input != nullptr && input->isa() && input->seen_ != seen && - input->func_graph() == current_func_graph) { - sorted_nodes.emplace_back(input); - input->seen_ = seen; + if (input == nullptr || !input->isa() || input->func_graph() != current_func_graph) { + continue; + } + auto checked_item = checked_cnodes.find(input); + if (checked_item == checked_cnodes.end()) { // Not met before. + checked_cnodes.insert({input, sorted_nodes.size()}); + sorted_nodes.emplace_back(std::pair(input, false)); + } else { // Checked, should update flag and new position. + auto pos = checked_item->second; + sorted_nodes[pos].second = true; // Set ignore flag. + checked_cnodes[input] = sorted_nodes.size(); // Update a new position. + sorted_nodes.emplace_back(std::pair(input, false)); // Insert duplicate node into new position. } } } + index++; } return sorted_nodes; } +// Return CNodes set that root at the bottom. +static inline std::vector SortReverseCNodes(const AnfNodePtr &root_node) { + std::vector res; + auto nodes_with_flag = SortCNodesContainDup(root_node); + for (auto it = nodes_with_flag.crbegin(); it != nodes_with_flag.crend(); it++) { + if (it->second) { // Check ignored flag. + continue; + } + res.emplace_back(it->first); + } + return res; +} + EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list) { FuncGraphPtr fg = GetFuncGraph(engine, args_spec_list); MS_EXCEPTION_IF_NULL(fg); @@ -123,9 +147,12 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr << MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH) << ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; } - std::vector nodes = FastShadowSort(func_node); - for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { - const auto &node = *it; + auto nodes_with_flag = SortCNodesContainDup(func_node); + for (auto it = nodes_with_flag.crbegin(); it != nodes_with_flag.crend(); it++) { + if (it->second) { // Check ignored flag. + continue; + } + const auto &node = it->first; AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_); MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString() << ", node_conf: " << node_conf->ToString();