fix dataset_sink_mode=True issue with GPU

pull/13360/head
Islam Amin 4 years ago
parent c0f41deeae
commit 0fa5a443be

@ -132,7 +132,9 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
// set debugger // set debugger
void SetDebugger() { void SetDebugger() {
debugger_ = Debugger::GetInstance(); debugger_ = Debugger::GetInstance();
debugger_->Init(device_id_); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
} }
#endif #endif

@ -75,10 +75,9 @@ Debugger::Debugger()
CheckDebuggerEnabledParam(); CheckDebuggerEnabledParam();
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
device_target_ = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET); std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
MS_LOG(INFO) << "Debugger got device_target: " << device_target_; MS_LOG(INFO) << "Debugger got device_target: " << device_target;
CheckDebuggerEnabledParam(); if (device_target == kCPUDevice) {
if (device_target_ == kCPUDevice) {
MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU."; MS_LOG(WARNING) << "Not enabling debugger. Debugger does not support CPU.";
} else if (CheckDebuggerEnabled()) { } else if (CheckDebuggerEnabled()) {
// configure partial memory reuse // configure partial memory reuse
@ -101,12 +100,14 @@ Debugger::Debugger()
} }
} }
void Debugger::Init(const uint32_t device_id) { void Debugger::Init(const uint32_t device_id, const std::string device_target) {
// access lock for public method // access lock for public method
std::lock_guard<std::mutex> a_lock(access_lock_); std::lock_guard<std::mutex> a_lock(access_lock_);
// save device_id // save device_id
MS_LOG(INFO) << "Debugger got device_id: " << device_id; MS_LOG(INFO) << "Debugger got device_id: " << device_id;
device_id_ = device_id; device_id_ = device_id;
MS_LOG(INFO) << "Debugger got device_target: " << device_target;
device_target_ = device_target;
version_ = "1.2.0"; version_ = "1.2.0";
} }

@ -67,7 +67,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
// init // init
// only save device_id // only save device_id
void Init(const uint32_t device_id); void Init(const uint32_t device_id, const std::string device_target);
// reset debugger // reset debugger
void Reset(); void Reset();

Loading…
Cancel
Save