diff --git a/mindspore/ccsrc/debug/trace.cc b/mindspore/ccsrc/debug/trace.cc index 22674cacd1..e12a7b1209 100644 --- a/mindspore/ccsrc/debug/trace.cc +++ b/mindspore/ccsrc/debug/trace.cc @@ -124,6 +124,8 @@ class AnalyzedFuncGraphExporter : public AnfExporter { void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); void OutputCNodes(std::ofstream &ofs, const std::vector &nodes, const FuncGraphPtr &func_graph); + void OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, int *idx, + std::map *const apply_map); private: std::string GetNodeType(const AnfNodePtr &nd) override; @@ -169,7 +171,7 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { } auto abs = ret->abstract(); if (abs == nullptr) { - return nullptr; + return "Undefined"; } auto dtype = abs->BuildType(); auto shape = abs->BuildShape(); @@ -247,6 +249,51 @@ AnalysisContextPtr AnalyzedFuncGraphExporter::ProcessFuncGraphCall(const CNodePt return ctx; } +void AnalyzedFuncGraphExporter::OutputCNode(std::ofstream &ofs, const CNodePtr &cnode, const FuncGraphPtr &func_graph, + int *idx, std::map *const apply_map) { + auto &inputs = cnode->inputs(); + std::string op_text = GetAnfNodeText(func_graph, inputs[0], *apply_map); + // non-return node + if (cnode != func_graph->get_return()) { + int apply_idx = (*idx)++; + (*apply_map)[cnode] = apply_idx; + std::string type_info = GetNodeType(cnode); + if (type_info == "Undefined") { + ofs << " %" << apply_idx << " = " << op_text << "("; + } else { + ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "("; + } + } else { + ofs << " " << op_text << "("; + } + + for (size_t i = 1; i < inputs.size(); ++i) { + if (i != 1) { + ofs << ", "; + } + AnfNodePtr arg = inputs[i]; + ofs << GetAnfNodeText(func_graph, arg, *apply_map); + } + ofs << ")"; + + // process function graph call + auto ctx = ProcessFuncGraphCall(cnode); + + // output comment + OutputStatementComment(ofs, cnode); + if (ctx != nullptr) { + ofs << " @ctx.addr=" << ctx.get(); + } + ofs << "\n"; + + if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { + ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" + << label_manage::Label(cnode->debug_info()) << "\n"; + } else { + ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n"; + } +} + void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vector &nodes, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { @@ -267,47 +314,7 @@ void AnalyzedFuncGraphExporter::OutputCNodes(std::ofstream &ofs, const std::vect } auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); - // non-return node - if (node != func_graph->get_return()) { - int apply_idx = idx++; - apply_map[node] = apply_idx; - std::string type_info = GetNodeType(node); - if (type_info == "Undefined") { - ofs << " %" << apply_idx << " = " << op_text << "("; - } else { - ofs << " %" << apply_idx << " : " << type_info << " = " << op_text << "("; - } - } else { - ofs << " " << op_text << "("; - } - - for (size_t i = 1; i < inputs.size(); ++i) { - if (i != 1) { - ofs << ", "; - } - AnfNodePtr arg = inputs[i]; - ofs << GetAnfNodeText(func_graph, arg, apply_map); - } - ofs << ")"; - - // process function graph call - auto ctx = ProcessFuncGraphCall(cnode); - - // output comment - OutputStatementComment(ofs, cnode); - if (ctx != nullptr) { - ofs << " @ctx.addr=" << ctx.get(); - } - ofs << "\n"; - - if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) { - ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" - << label_manage::Label(cnode->debug_info()) << "\n"; - } else { - ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n"; - } + OutputCNode(ofs, cnode, func_graph, &idx, &apply_map); } } diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index 9591fef10d..7f33d4a3c7 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -76,44 +76,56 @@ bool CompareTensorScalarType(const TypeId &tensor_type, const size_t &t_type_num return true; } -void setMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, +void SetMaxType(TypeId *max_type_id, TypeId *max_type, size_t *max_type_number, const TypeId type_id, const TypeId type, const size_t type_number) { *max_type_id = type_id; *max_type = type; *max_type_number = type_number; } -TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indexs, - const std::set &write_indexs) { +bool GetTensorOrScalarTypeInfo(AbstractBasePtr arg_value, bool is_write, TypeId *arg_type_id, + TypeId *arg_type = nullptr) { + if (arg_value->isa()) { + if (is_write) { + arg_value = arg_value->cast()->ref_origin(); + } else { + arg_value = arg_value->cast()->ref(); + } + } + if (arg_value->isa()) { + auto tensor = arg_value->cast(); + auto tensor_type = tensor->element()->BuildType(); + MS_EXCEPTION_IF_NULL(tensor_type); + *arg_type_id = tensor_type->type_id(); + if (arg_type != nullptr) { + *arg_type = kObjectTypeTensorType; + } + return true; + } + if (arg_value->isa()) { + auto scalar = arg_value->cast(); + auto scalar_type = scalar->BuildType(); + MS_EXCEPTION_IF_NULL(scalar_type); + *arg_type_id = scalar_type->type_id(); + if (arg_type != nullptr) { + *arg_type = kObjectTypeNumber; + } + return true; + } + return false; +} + +TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::vector indices, + const std::set &write_indices) { TypeId max_type_id = kTypeUnknown; TypeId max_type = kTypeUnknown; size_t max_type_number = 0; bool has_int8 = false; - for (const auto &index : indexs) { + for (const auto &index : indices) { TypeId arg_type_id = kTypeUnknown; TypeId arg_type = kTypeUnknown; - AbstractBasePtr arg_value = args_spec_list[index]; - if (arg_value->isa()) { - auto is_write = (write_indexs.find(index) != write_indexs.end()); - if (is_write) { - arg_value = arg_value->cast()->ref_origin(); - } else { - arg_value = arg_value->cast()->ref(); - } - } - if (arg_value->isa()) { - auto tensor = arg_value->cast(); - auto tensor_type = tensor->element()->BuildType(); - MS_EXCEPTION_IF_NULL(tensor_type); - arg_type_id = tensor_type->type_id(); - arg_type = kObjectTypeTensorType; - } else if (arg_value->isa()) { - auto scalar = arg_value->cast(); - auto scalar_type = scalar->BuildType(); - MS_EXCEPTION_IF_NULL(scalar_type); - arg_type_id = scalar_type->type_id(); - arg_type = kObjectTypeNumber; - } else { + auto is_write = (write_indices.find(index) != write_indices.end()); + if (!GetTensorOrScalarTypeInfo(args_spec_list[index], is_write, &arg_type_id, &arg_type)) { continue; } auto it = type_map.find(arg_type_id); @@ -124,22 +136,22 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve has_int8 = true; } if (max_type_id == kTypeUnknown) { - setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); + SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); continue; } if (max_type == arg_type) { if (it->second > max_type_number) { - setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); + SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); } } else { if (arg_type == kObjectTypeTensorType) { if (CompareTensorScalarType(arg_type_id, it->second, max_type_id, max_type_number)) { - setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); + SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); } } else { if (!CompareTensorScalarType(max_type_id, max_type_number, arg_type_id, it->second)) { - setMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); + SetMaxType(&max_type_id, &max_type, &max_type_number, arg_type_id, arg_type, it->second); } } } @@ -154,28 +166,28 @@ TypeId GetMaxTypeId(const abstract::AbstractBasePtrList &args_spec_list, std::ve // Get the largest type of index in the same SignatureEnumDType of arguments. std::map GetMaxDtype(const std::vector &dtypes, const abstract::AbstractBasePtrList &args_spec_list, - const std::set &write_indexs) { + const std::set &write_indices) { // record index for signature.dtypes of the same type // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} - std::map> type_indexs; + std::map> type_indices; for (size_t i = 0; i < dtypes.size(); ++i) { - auto it = type_indexs.find(dtypes[i]); - if (it == type_indexs.end()) { - (void)type_indexs.insert(std::make_pair(dtypes[i], std::vector{i})); + auto it = type_indices.find(dtypes[i]); + if (it == type_indices.end()) { + (void)type_indices.insert(std::make_pair(dtypes[i], std::vector{i})); } else { it->second.push_back(i); } } std::map dst_type; - for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) { + for (auto it = type_indices.begin(); it != type_indices.end(); (void)++it) { auto type = it->first; - auto indexs = it->second; + auto indices = it->second; // If the number of arguments belonging to the same SignatureEnumDType is less than 2, skip it. - if (indexs.size() < 2) { + if (indices.size() < 2) { continue; } bool has_tensor = false; - for (const auto &index : indexs) { + for (const auto &index : indices) { AbstractBasePtr arg_value = args_spec_list[index]; if (arg_value->isa()) { arg_value = arg_value->cast()->ref(); @@ -189,7 +201,7 @@ std::map GetMaxDtype(const std::vector &signature, const abstract::AbstractBasePtrList &args_spec_list, const FuncGraphPtr &graph, - std::vector *const op_inputs, const std::set &write_indexs) { + std::vector *const op_inputs, const std::set &write_indices) { std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), [](const Signature &sig) { return sig.dtype; }); @@ -213,36 +225,19 @@ void DoAutoCast(const std::string &func_name, const std::vector &sign return; } // Stat the index of the arguments with the largest type in the same SignatureEnumDType. - std::map dst_type = GetMaxDtype(dtypes, args_spec_list, write_indexs); + std::map dst_type = GetMaxDtype(dtypes, args_spec_list, write_indices); // Identify which arg requires auto cast for (size_t i = 0; i < args_spec_list.size(); ++i) { auto it = dst_type.find(dtypes[i]); if (it == dst_type.end() || it->second == kTypeUnknown) { continue; } - auto rw_it = write_indexs.find(i); - auto is_write = (rw_it != write_indexs.end()); + auto rw_it = write_indices.find(i); + auto is_write = (rw_it != write_indices.end()); - AbstractBasePtr arg_value = args_spec_list[i]; - if (arg_value->isa()) { - if (is_write) { - arg_value = arg_value->cast()->ref_origin(); - } else { - arg_value = arg_value->cast()->ref(); - } - } TypeId arg_type_id = kTypeUnknown; - if (arg_value->isa()) { - auto tensor = arg_value->cast(); - auto tensor_type = tensor->element()->BuildType(); - MS_EXCEPTION_IF_NULL(tensor_type); - arg_type_id = tensor_type->type_id(); - } else if (arg_value->isa()) { - auto scalar = arg_value->cast(); - auto scalar_type = scalar->BuildType(); - MS_EXCEPTION_IF_NULL(scalar_type); - arg_type_id = scalar_type->type_id(); - } + AbstractBasePtr arg_value = args_spec_list[i]; + (void)GetTensorOrScalarTypeInfo(arg_value, is_write, &arg_type_id); auto it_map = type_map.find(arg_type_id); if (it_map == type_map.end()) { continue; @@ -279,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func } } std::vector op_inputs; - std::set write_indexs; + std::set write_indices; op_inputs.push_back(NewValueNode(function)); // Assume, the write input of op is always the first input. We check if any write op, // and add cast op on other inputs to keep the same type with assigned parameter. @@ -303,7 +298,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefValue), param}); } else if (sig == SignatureEnumRW::kRWWrite) { param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefOrigin), param}); - write_indexs.insert(i); + write_indices.insert(i); } // If sig is SignatureEnumRW::kRWRef, not do anything. } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) { @@ -313,7 +308,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func } // process default ProcessDefault(func_name, args_spec_list, signature, has_var, &op_inputs); - DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indexs); + DoAutoCast(func_name, signature, args_spec_list, func_graph, &op_inputs, write_indices); return func_graph->NewCNode(op_inputs); } } // namespace diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc index 47f7c3d143..4b72c5b6d2 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/parse/data_converter.cc @@ -238,6 +238,31 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) { return bprop_graph; } +bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { + FuncGraphPtr func_graph = ConvertToFuncGraph(obj); + if (func_graph == nullptr) { + MS_LOG(ERROR) << "Parse resolve function error."; + return false; + } + // if the cell object has specified bprop, it has user-defined bprop function parse and record it + if (py::hasattr(obj, "bprop")) { + FuncGraphPtr bprop_graph = nullptr; + bool enable_bprop_debug = py::cast(py::getattr(obj, "bprop_debug")); + if (enable_bprop_debug) { + bprop_graph = ConvertToBpropCut(obj); + } else { + bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); + } + if (bprop_graph != nullptr) { + (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph))); + (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); + func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); + } + } + *data = func_graph; + return true; +} + bool ConvertOtherObj(py::object obj, ValuePtr *const data) { auto obj_type = data_converter::GetObjType(obj); MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; @@ -262,32 +287,12 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { // Create the namespace for common class instance // When the obj is Cell, default parse the 'construct' if (data_converter::IsCellInstance(obj)) { - FuncGraphPtr func_graph = ConvertToFuncGraph(obj); - if (func_graph == nullptr) { - MS_LOG(ERROR) << "Parse resolve function error."; - return false; - } - // if the cell object has specified bprop, it has user-defined bprop function parse and record it - if (py::hasattr(obj, "bprop")) { - FuncGraphPtr bprop_graph = nullptr; - bool enable_bprop_debug = py::cast(py::getattr(obj, "bprop_debug")); - if (enable_bprop_debug) { - bprop_graph = ConvertToBpropCut(obj); - } else { - bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD); - } - if (bprop_graph != nullptr) { - (void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph))); - (void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph))); - func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true); - } - } - *data = func_graph; - } else { - py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); - py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); - *data = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); + return ConvertCellObjToFuncGraph(obj, data); } + + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); + *data = std::make_shared(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); return true; } MS_LOG(ERROR) << "Resolve type is invalid " << ((std::string)py::str(obj)); diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index b94340ad7b..7df5c6fea0 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -608,7 +608,7 @@ void Pipeline::Run() { MS_LOG(INFO) << "End"; } -void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list) { +void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) { std::size_t size = args.size(); for (std::size_t i = 0; i < size; i++) { diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h index c846b04888..633ff78c0b 100644 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ b/mindspore/ccsrc/pipeline/pipeline.h @@ -139,7 +139,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc const std::vector &types, const std::vector> &shapes, const std::vector &input_indexes, bool need_run); -void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *arg_list); +void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc index b7520176ec..0331fc9c07 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc @@ -464,6 +464,85 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector return ExecuteMultipleEvaluators(evaluators, out_conf, args_conf_list); } +void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator) { + auto fg_eval = evaluator->cast(); + if (fg_eval == nullptr) { + return; + } + auto fg = fg_eval->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto undetermined_fgs = fg->recursive_graphs(); + if (undetermined_fgs) { + auto fg_parent = fg->parent(); + MS_EXCEPTION_IF_NULL(fg_parent); + fg_parent->set_flags(kFuncGraphFlagUndetermined, true); + MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); + } +} + +EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector &evaluators, + const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, + const EvalTraceRevIter &it, bool *continue_flag) { + *continue_flag = false; + // Find latest entry function to handle nested recursion. + EvaluatorPtr latest_entry = eval; + auto latest_entry_iter = eval_trace_.rbegin(); + for (auto r_it = eval_trace_.rbegin(); *r_it != *it;) { + auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); + if (it_temp != evaluators.end()) { + latest_entry = *it_temp; + latest_entry_iter = r_it; + break; + } + latest_entry_iter = ++r_it; + } + if (latest_entry != eval) { + MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); + *continue_flag = true; + return latest_entry; + } + + bool has_undetermined = false; + // Check whether sub loop has untraced undetermined evaluator. + std::set> undetermined_evals; + for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { + undetermined_evals.insert(*r_it); + } + MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); + + for (auto u_eval : undetermined_evals) { + MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; + if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { + MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; + has_undetermined = true; + break; + } + } + if (has_undetermined == false) { + MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; + *continue_flag = true; + return latest_entry; + } + + return latest_entry; +} + +EvalResultPtr AnalysisEngine::ProcessEvalResults(const AbstractBasePtrList &out_specs) { + if (out_specs.size() == 0) { + MS_LOG(EXCEPTION) << "There is an endless loop for evaluator."; + } + + if (out_specs.size() == 1) { + MS_EXCEPTION_IF_NULL(out_specs[0]); + // If only one result derived, then broaden it to avoid wrong constant propagation. + return std::make_shared(out_specs[0]->Broaden(), std::make_shared()); + } + auto joined_spec = AbstractJoin(out_specs); + MS_EXCEPTION_IF_NULL(joined_spec); + MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); + return std::make_shared(joined_spec, std::make_shared()); +} + EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { @@ -479,18 +558,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorGetEvaluatedValue()->abstract(); }); for (auto eval : evaluators) { - auto fg_eval = eval->cast(); - if (fg_eval) { - auto fg = fg_eval->func_graph(); - MS_EXCEPTION_IF_NULL(fg); - auto undetermined_fgs = fg->recursive_graphs(); - if (undetermined_fgs) { - auto fg_parent = fg->parent(); - MS_EXCEPTION_IF_NULL(fg_parent); - fg_parent->set_flags(kFuncGraphFlagUndetermined, true); - MS_LOG(DEBUG) << "Set graph undetermined: " << fg_parent->ToString(); - } - } + SetUndeterminedFlag(eval); auto current_inf = std::make_pair(eval, args_spec_list); MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString(); @@ -510,40 +578,9 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorfirst); - if (it_temp != evaluators.end()) { - latest_entry = *it_temp; - latest_entry_iter = r_it; - break; - } - latest_entry_iter = ++r_it; - } - if (latest_entry != eval) { - MS_LOG(DEBUG) << "Continue Evaluator " << eval->ToString(); - continue; - } - - bool has_undetermined = false; - // Check whether sub loop has untraced undetermined evaluator. - std::set> undetermined_evals; - for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { - undetermined_evals.insert(*r_it); - } - MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); - for (auto u_eval : undetermined_evals) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; - if (!undetermined_evals.count(std::make_pair(multi_poss_[u_eval.first], args_spec_list))) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; - has_undetermined = true; - break; - } - } - if (has_undetermined == false) { - MS_LOG(DEBUG) << eval->ToString() << " has no undetermined."; + bool continue_flag = false; + auto latest_entry = HandleNestedRecursion(evaluators, eval, args_spec_list, it, &continue_flag); + if (continue_flag) { continue; } @@ -558,19 +595,8 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector(out_specs[0]->Broaden(), std::make_shared()); - } - auto joined_spec = AbstractJoin(out_specs); - MS_EXCEPTION_IF_NULL(joined_spec); - MS_LOG(DEBUG) << "Multiple evaluators joined: " << joined_spec->ToString(); - return std::make_shared(joined_spec, std::make_shared()); + return ProcessEvalResults(out_specs); } EvalResultPtr AnfNodeConfig::GetEvaluatedValue() { diff --git a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h index 1e7a52fda9..a0b7ee5478 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/static_analysis/static_analysis.h @@ -172,6 +172,8 @@ struct AnalysisResult { AnalysisContextPtr context; }; +using EvalTraceRevIter = std::list>::reverse_iterator; + class AnalysisEngine : public std::enable_shared_from_this { public: AnalysisEngine(const PrimEvaluatorMap &prim_evaluator_map, const FuncGraphManagerPtr &func_graph_manager) @@ -222,6 +224,12 @@ class AnalysisEngine : public std::enable_shared_from_this { std::unordered_map prim_py_evaluators_; private: + void SetUndeterminedFlag(const EvaluatorPtr &evaluator); + EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, + const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, + bool *continue_flag); + EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs); + const PrimEvaluatorMap &prim_constructors_; FuncGraphManagerPtr func_graph_manager_; std::unordered_map constructors_;