diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 0891e211a0..98cd2f4b2f 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -34,6 +34,7 @@ #include "utils/utils.h" #include "debug/trace.h" #include "utils/context/ms_context.h" +#include "operator/ops.h" namespace mindspore { // max number of elements in sequence @@ -69,7 +70,7 @@ py::object load_obj(const std::string& path) { // ============================================= MindSpore IR Exporter ============================================= -std::string GetNodeType(const AnfNodePtr& nd) { +std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast(nd->Shape()); TypePtr type = dyn_cast(nd->Type()); std::ostringstream oss; @@ -102,7 +103,7 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& FuncGraphPtr fg = func_graph; while (fg != nullptr) { if (exported.find(fg) == exported.end()) { - if (!export_used_) { + if (!check_integrity_) { break; } MS_LOG(EXCEPTION) << "Can not find func graph '" << fg->DumpText() << "." << fg->debug_info()->get_id() << "'"; @@ -255,15 +256,15 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { } // output primitive attributes - auto attrs = prim->attrs(); - if (attrs.size() > 0) { - oss << "["; - int i = 0; - for (auto& attr : attrs) { - oss << (i > 0 ? ", " : "") << attr.first << "=" << attr.second->DumpText(); - i++; + oss << prim->GetAttrsText(); + + if (prim->isa()) { + auto do_signature = dyn_cast(prim); + auto& func = do_signature->function(); + if (func->isa()) { + auto sig_prim = dyn_cast(func); + oss << sig_prim->GetAttrsText(); } - oss << "]"; } return oss.str(); @@ -351,7 +352,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { std::ostringstream oss; - if (export_used_) { + if (check_integrity_) { MS_LOG(EXCEPTION) << "Need to process type: " << value->type_name() << ", dump text: " << value->DumpText(); } oss << value->type_name() << "[" << value->DumpText() << "]"; @@ -420,7 +421,7 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An } oss << "%" << iter->second; } else if (node->isa()) { - oss << "%para" << GetParamIndex(func_graph, node, export_used_); + oss << "%para" << GetParamIndex(func_graph, node, check_integrity_); } else if (IsValueNode(node)) { FuncGraphPtr fg = GetValueNode(node); oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id(); diff --git a/mindspore/ccsrc/debug/anf_ir_utils.h b/mindspore/ccsrc/debug/anf_ir_utils.h index 5b9ac9d2f0..5342c1ab96 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.h +++ b/mindspore/ccsrc/debug/anf_ir_utils.h @@ -64,17 +64,18 @@ struct ParamPtrHasher { class AnfExporter { public: - explicit AnfExporter(const std::string& id, bool export_used = true) - : param_index(-1), id_(id), export_used_(export_used) { + explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) + : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { func_graph_set.clear(); exported.clear(); } - ~AnfExporter() {} + virtual ~AnfExporter() {} void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); void ExportFuncGraph(const std::string& filename, const std::vector& graphs); - private: + protected: + virtual std::string GetNodeType(const AnfNodePtr& nd); int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); int GetParamIndexFromExported(const AnfNodePtr& param); std::string DumpObject(const py::object& obj, const std::string& category) const; @@ -101,8 +102,10 @@ class AnfExporter { OrderedSet func_graph_set{}; OrderedMap> exported; std::string id_; - bool export_used_ = true; // whether export function graphs used in current exporting function graph + bool export_used_ = true; // whether export function graphs used in current exporting function graph + bool check_integrity_ = false; // whether check integrity or not, when dumping ir for loading, must set it to true TaggedNodeMap tagged_cnodes_; + abstract::AnfNodeConfigPtr node_cfg_ = nullptr; }; void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); @@ -115,7 +118,6 @@ std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); -std::string GetNodeType(const AnfNodePtr& nd); } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ diff --git a/mindspore/ccsrc/debug/trace.cc b/mindspore/ccsrc/debug/trace.cc index 4b0f4c4fb3..7ce13052c5 100644 --- a/mindspore/ccsrc/debug/trace.cc +++ b/mindspore/ccsrc/debug/trace.cc @@ -17,6 +17,7 @@ #include "debug/trace.h" #include +#include #include #include #include @@ -194,37 +195,116 @@ void TraceGraphInfer() { MS_LOG(INFO) << "\n*************************************************************************************"; } -void OutputAnalysisGraphInfo() { - MS_LOG(INFO) << "Output analysis graph begin"; - std::unordered_map index_map; - std::vector tagged_graphs; +class AnalyzedFuncGraphExporter : public AnfExporter { + public: + AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} + ~AnalyzedFuncGraphExporter() override = default; + void ExportFuncGraph(const std::string& filename, const std::vector& node_cfgs); + + private: + std::string GetNodeType(const AnfNodePtr& nd) override; +}; + +std::unordered_map CalcTaggedFuncGraphs() { + std::unordered_map tagged_func_graphs; auto& list = GetCNodeDebugStack(); for (size_t i = 0; i < list.size(); ++i) { - auto& node_cfg = list[i]; + auto node_cfg = list[i]; auto fg = node_cfg->context()->func_graph(); auto node = node_cfg->node(); - auto idx = tagged_graphs.size(); - std::pair item(fg, idx); - if (index_map.insert(item).second) { - tagged_graphs.emplace_back(TaggedGraph(fg, TaggedNodeMap())); + tagged_func_graphs[fg][node] = i; + } + return tagged_func_graphs; +} + +void OutputAnalyzedGraphWithType() { + AnalyzedFuncGraphExporter exporter; + exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); +} + +std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { + if (node_cfg_ == nullptr) { + return AnfExporter::GetNodeType(node); + } + auto ctx = node_cfg_->context(); + auto engine = node_cfg_->engine(); + auto cfg = engine->MakeConfig(node, ctx); + auto abs = engine->cache().GetValue(cfg); + + if (abs == nullptr) { + return "Undefined"; + } + auto dtype = abs->BuildType(); + auto shape = abs->BuildShape(); + std::ostringstream oss; + if (dtype != nullptr && abs->isa() && shape != nullptr) { + oss << dtype->DumpText() << shape->DumpText(); + } else if (dtype != nullptr) { + oss << dtype->DumpText(); + } else { + oss << "Undefined"; + } + return oss.str(); +} + +void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, + const std::vector& node_cfgs) { + if (node_cfgs.empty()) { + MS_LOG(DEBUG) << "Node configs is empty"; + return; + } + + std::ofstream ofs(filename); + if (!ofs.is_open()) { + MS_LOG(ERROR) << "Open file '" << filename << "' failed!"; + return; + } + + param_index = 1; + auto tagged_func_graphs = CalcTaggedFuncGraphs(); + + // first output grapn on the analysis stack + for (const auto& node_cfg : node_cfgs) { + auto fg = node_cfg->context()->func_graph(); + // the graph is already output, skip it + if (exported.find(fg) != exported.end()) { + continue; } - tagged_graphs[index_map[fg]].second[node] = i; + // set node_cfg info for getting type + node_cfg_ = node_cfg; + tagged_cnodes_ = tagged_func_graphs[fg]; + ExportOneFuncGraph(ofs, fg); + ofs << "\n\n"; + } + + node_cfg_ = nullptr; + tagged_cnodes_.clear(); + + // print seperator between function graphs on analyzed graph call stack and others + ofs << "#===============================================================================\n\n\n"; + + // second output other graphs + while (!func_graph_set.empty()) { + FuncGraphPtr fg = *func_graph_set.begin(); + ExportOneFuncGraph(ofs, fg); + ofs << "\n\n"; + (void)func_graph_set.erase(fg); } + ofs << "# num of total funcgraphs: " << exported.size(); - ExportIR("analyze_fail.dat", tagged_graphs); - MS_LOG(INFO) << "Output analysis graph *end*"; + ofs.close(); } void GetInferStackInfo(std::ostringstream& oss) { MS_LOG(INFO) << "Get graph analysis information begin"; - auto& stack = GetCNodeDebugStack(); + auto stack = GetCNodeDebugStack(); if (stack.empty()) { MS_LOG(INFO) << "Length of analysis information stack is empty."; return; } - OutputAnalysisGraphInfo(); + OutputAnalyzedGraphWithType(); oss << "\nThe function call stack:\n"; int index = 0; diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc index d62553ef60..a576c1e76b 100644 --- a/mindspore/ccsrc/ir/primitive.cc +++ b/mindspore/ccsrc/ir/primitive.cc @@ -106,6 +106,27 @@ void Primitive::set_signatures( } } +std::string Primitive::GetAttrsText() const { + if (attrs_.empty()) { + return ""; + } + + std::ostringstream oss; + oss << "["; + bool is_first = true; + for (auto& attr : attrs_) { + if (is_first) { + is_first = false; + } else { + oss << ", "; + } + oss << attr.first << "=" << attr.second->DumpText(); + } + oss << "]"; + + return oss.str(); +} + py::function PrimitivePy::GetBpropFunction() { static const char* const get_bprop_func_name = "get_bprop"; if (py::hasattr(python_obj_, get_bprop_func_name)) { diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h index 8a60412e44..7dd37eb15f 100644 --- a/mindspore/ccsrc/ir/primitive.h +++ b/mindspore/ccsrc/ir/primitive.h @@ -102,6 +102,7 @@ class Primitive : public Named { PrimType prim_type() const { return prim_type_; } std::string instance_name() const { return instance_name_; } + std::string GetAttrsText() const; bool operator==(const Value& other) const override; bool operator==(const Primitive& other) const; ~Primitive() override = default; diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index 984f24b326..d9dd9e5e99 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -22,6 +22,7 @@ #include "operator/ops.h" #include "pipeline/static_analysis/prim.h" #include "pipeline/static_analysis/abstract_function.h" +#include "debug/trace.h" namespace mindspore { using Shape = abstract::Shape; @@ -124,6 +125,7 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { AbstractBasePtrList args_spec_list = {tuple_tensor, start_index}; try { + trace::ClearTraceStack(); engine_->Run(tupleSliceGraphPtr, args_spec_list); FAIL() << "Excepted exception :Args type is wrong"; } catch (std::runtime_error const &err) {