diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 736fc3afe1..8d64272178 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -577,7 +577,7 @@ void AscendSession::Execute(const std::shared_ptr &kernel_graph, bo void AscendSession::Dump(const std::shared_ptr &kernel_graph) const { MS_LOG(INFO) << "Start!"; MS_EXCEPTION_IF_NULL(kernel_graph); - E2eDumpUtil::DumpData(kernel_graph.get()); + E2eDumpUtil::DumpData(kernel_graph.get(), device_id_); MS_LOG(INFO) << "Finish!"; } diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 7da013a849..1d6ba5f9a9 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -335,7 +335,7 @@ void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_ void GPUSession::Dump(const std::shared_ptr &kernel_graph) const { if (debugger_->DebuggerBackendEnabled()) { MS_EXCEPTION_IF_NULL(kernel_graph); - E2eDumpUtil::DumpData(kernel_graph.get(), debugger_.get()); + E2eDumpUtil::DumpData(kernel_graph.get(), device_id_, debugger_.get()); } else { DumpJsonParser::GetInstance().UpdateDumpIter(); } diff --git a/mindspore/ccsrc/debug/data_dump/dump_json_parser.h b/mindspore/ccsrc/debug/data_dump/dump_json_parser.h index b8973ba595..b1b8c5ae91 100644 --- a/mindspore/ccsrc/debug/data_dump/dump_json_parser.h +++ b/mindspore/ccsrc/debug/data_dump/dump_json_parser.h @@ -47,7 +47,7 @@ class DumpJsonParser { uint32_t input_output() const { return input_output_; } uint32_t op_debug_mode() const { return op_debug_mode_; } bool trans_flag() const { return trans_flag_; } - uint32_t cur_dump_iter() { return cur_dump_iter_; } + uint32_t cur_dump_iter() const { return cur_dump_iter_; } void UpdateDumpIter() { ++cur_dump_iter_; } bool InputNeedDump() const; bool OutputNeedDump() const; diff --git a/mindspore/ccsrc/debug/data_dump/e2e_dump_util.cc b/mindspore/ccsrc/debug/data_dump/e2e_dump_util.cc index 664dab47d3..5c054a5ebd 100644 --- a/mindspore/ccsrc/debug/data_dump/e2e_dump_util.cc +++ b/mindspore/ccsrc/debug/data_dump/e2e_dump_util.cc @@ -192,7 +192,7 @@ void E2eDumpUtil::DumpParameters(const session::KernelGraph *graph, const std::s } } -bool E2eDumpUtil::DumpData(const session::KernelGraph *graph, Debugger *debugger) { +bool E2eDumpUtil::DumpData(const session::KernelGraph *graph, uint32_t device_id, Debugger *debugger) { MS_EXCEPTION_IF_NULL(graph); auto &dump_json_parser = DumpJsonParser::GetInstance(); dump_json_parser.UpdateDumpIter(); @@ -208,9 +208,6 @@ bool E2eDumpUtil::DumpData(const session::KernelGraph *graph, Debugger *debugger } } MS_LOG(INFO) << "Start e2e dump. Current iteration is " << dump_json_parser.cur_dump_iter(); - auto context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context); - auto device_id = context->get_param(MS_CTX_DEVICE_ID); std::string net_name = dump_json_parser.net_name(); std::string iterator = std::to_string(dump_json_parser.cur_dump_iter()); diff --git a/mindspore/ccsrc/debug/data_dump/e2e_dump_util.h b/mindspore/ccsrc/debug/data_dump/e2e_dump_util.h index 4893b746a1..bc57cb6067 100644 --- a/mindspore/ccsrc/debug/data_dump/e2e_dump_util.h +++ b/mindspore/ccsrc/debug/data_dump/e2e_dump_util.h @@ -28,7 +28,7 @@ class E2eDumpUtil { public: E2eDumpUtil() = default; ~E2eDumpUtil() = default; - static bool DumpData(const session::KernelGraph *graph, Debugger *debugger = nullptr); + static bool DumpData(const session::KernelGraph *graph, uint32_t device_id, Debugger *debugger = nullptr); private: static void DumpOutput(const session::KernelGraph *graph, const std::string &dump_path, Debugger *debugger);