!7176 FIx gpu dump wrong device_id

Merge pull request !7176 from caifubi/gpu_dump
pull/7176/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 87df41f194

@ -577,7 +577,7 @@ void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bo
void AscendSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
MS_EXCEPTION_IF_NULL(kernel_graph);
E2eDumpUtil::DumpData(kernel_graph.get());
E2eDumpUtil::DumpData(kernel_graph.get(), device_id_);
MS_LOG(INFO) << "Finish!";
}

@ -335,7 +335,7 @@ void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_
void GPUSession::Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const {
if (debugger_->DebuggerBackendEnabled()) {
MS_EXCEPTION_IF_NULL(kernel_graph);
E2eDumpUtil::DumpData(kernel_graph.get(), debugger_.get());
E2eDumpUtil::DumpData(kernel_graph.get(), device_id_, debugger_.get());
} else {
DumpJsonParser::GetInstance().UpdateDumpIter();
}

@ -47,7 +47,7 @@ class DumpJsonParser {
uint32_t input_output() const { return input_output_; }
uint32_t op_debug_mode() const { return op_debug_mode_; }
bool trans_flag() const { return trans_flag_; }
uint32_t cur_dump_iter() { return cur_dump_iter_; }
uint32_t cur_dump_iter() const { return cur_dump_iter_; }
void UpdateDumpIter() { ++cur_dump_iter_; }
bool InputNeedDump() const;
bool OutputNeedDump() const;

@ -192,7 +192,7 @@ void E2eDumpUtil::DumpParameters(const session::KernelGraph *graph, const std::s
}
}
bool E2eDumpUtil::DumpData(const session::KernelGraph *graph, Debugger *debugger) {
bool E2eDumpUtil::DumpData(const session::KernelGraph *graph, uint32_t device_id, Debugger *debugger) {
MS_EXCEPTION_IF_NULL(graph);
auto &dump_json_parser = DumpJsonParser::GetInstance();
dump_json_parser.UpdateDumpIter();
@ -208,9 +208,6 @@ bool E2eDumpUtil::DumpData(const session::KernelGraph *graph, Debugger *debugger
}
}
MS_LOG(INFO) << "Start e2e dump. Current iteration is " << dump_json_parser.cur_dump_iter();
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
auto device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
std::string net_name = dump_json_parser.net_name();
std::string iterator = std::to_string(dump_json_parser.cur_dump_iter());

@ -28,7 +28,7 @@ class E2eDumpUtil {
public:
E2eDumpUtil() = default;
~E2eDumpUtil() = default;
static bool DumpData(const session::KernelGraph *graph, Debugger *debugger = nullptr);
static bool DumpData(const session::KernelGraph *graph, uint32_t device_id, Debugger *debugger = nullptr);
private:
static void DumpOutput(const session::KernelGraph *graph, const std::string &dump_path, Debugger *debugger);

Loading…
Cancel
Save