From 62ae6802dc6c03b87da11ea530667b9c1e44fb04 Mon Sep 17 00:00:00 2001 From: kswang Date: Thu, 19 Nov 2020 14:59:16 +0800 Subject: [PATCH] fix context null error --- .../ccsrc/backend/session/ascend_session.cc | 7 +++++ .../ccsrc/backend/session/ascend_session.h | 27 ++----------------- mindspore/ccsrc/backend/session/cpu_session.h | 2 +- .../ccsrc/backend/session/gpu_session.cc | 2 +- .../ccsrc/backend/session/session_basic.cc | 2 +- .../ccsrc/backend/session/session_basic.h | 2 +- .../device/ascend/ascend_kernel_runtime.cc | 17 +++++++----- .../device/ascend/ascend_kernel_runtime.h | 1 + .../device/ascend/ascend_memory_manager.cc | 1 - .../ccsrc/runtime/device/kernel_runtime.h | 1 + 10 files changed, 26 insertions(+), 36 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index e718d03372..75a0f604a8 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -125,6 +125,13 @@ void InsertMakeTupleForOutput(NotNull root_graph) { } } // namespace +void AscendSession::Init(uint32_t device_id) { + InitExecutor(kAscendDevice, device_id); + auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->CreateContext(); +} + GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { MS_LOG(INFO) << "Start"; // construct graph, if successfully, graph_sum_ + 1 diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index cce006acce..6d79ff9f7e 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -31,7 +31,6 @@ #include "backend/kernel_compiler/kernel.h" #include "backend/session/session_factory.h" #include "backend/session/ascend_control_parser.h" -#include "runtime/context.h" namespace mindspore { namespace session { @@ -40,28 +39,8 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, class AscendSession : public SessionBasic { public: AscendSession() { final_graph_id_ = kInvalidGraphId; } - ~AscendSession() { - if (rt_context_ != nullptr) { - auto ret = rtCtxDestroy(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]"; - } - rt_context_ = nullptr; - } - } - - void Init(uint32_t device_id) override { - InitDevice(kAscendDevice, device_id); - auto ret = rtCtxCreate(&rt_context_, 0, device_id); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast(ret) << "]"; - } - ret = rtCtxSetCurrent(rt_context_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; - } - } - + ~AscendSession() = default; + void Init(uint32_t device_id) override; // get graph id of final graph GraphId GetFinalRunGraph() const override { return final_graph_id_; } @@ -136,8 +115,6 @@ class AscendSession : public SessionBasic { std::map, tensor::TensorPtr> initial_tenosrs_; // final_graph_id is used in every root graph has it's own session situation GraphId final_graph_id_; - // ascend runtime context - rtContext_t rt_context_{nullptr}; }; MS_REG_SESSION(kAscendDevice, AscendSession); } // namespace session diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h index 3c52434547..20d28c00f2 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.h +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -29,7 +29,7 @@ class CPUSession : public SessionBasic { public: CPUSession() = default; ~CPUSession() override = default; - void Init(uint32_t device_id) override { InitDevice(kCPUDevice, device_id); } + void Init(uint32_t device_id) override { InitExecutor(kCPUDevice, device_id); } protected: void CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, VectorRef *, diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index b90787ff97..98045d5423 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -89,7 +89,7 @@ void GPUSession::Init(uint32_t device_id) { ms_context->set_param(MS_CTX_DEVICE_ID, device_id); MS_LOG(INFO) << "Set device id " << device_id << " for gpu session."; - InitDevice(kGPUDevice, device_id); + InitExecutor(kGPUDevice, device_id); } void GPUSession::SelectKernel(const std::shared_ptr &kernel_graph) const { diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 2b85cab3d1..129da1d9ec 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -335,7 +335,7 @@ bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) { GraphId SessionBasic::graph_sum_ = 0; -void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id) { +void SessionBasic::InitExecutor(const std::string &device_name, uint32_t device_id) { device_id_ = device_id; context_ = std::make_shared(device_name, device_id); executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 9553d2fcd7..68dcf88163 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -64,7 +64,7 @@ class SessionBasic : public std::enable_shared_from_this { virtual void Init(uint32_t device_id) { device_id_ = device_id; } - void InitDevice(const std::string &device_name, uint32_t device_id); + void InitExecutor(const std::string &device_name, uint32_t device_id); virtual ~SessionBasic() { summary_callback_ = nullptr; } diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 524cb6296e..c17b912763 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -683,6 +683,16 @@ bool AscendKernelRuntime::SyncStream() { return true; } +void AscendKernelRuntime::CreateContext() { + if (rt_context_ == nullptr) { + auto ret = rtCtxCreate(&rt_context_, 0, device_id_); + if (ret != RT_ERROR_NONE) { + MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast(ret) << "]"; + } + } + InnerSetContext(); +} + bool AscendKernelRuntime::InitDevice() { int device_count = 0; auto ret = rtGetDeviceCount(&device_count); @@ -713,12 +723,7 @@ bool AscendKernelRuntime::InitDevice() { MS_LOG(ERROR) << "Call rtCtxGetCurrent failed, ret[" << ret << "]"; return false; } - - ret = rtCtxCreate(&rt_context_, 0, device_id_); - if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast(ret) << "]"; - } - InnerSetContext(); + CreateContext(); ret = rtStreamCreate(&stream_, 0); if (ret != RT_ERROR_NONE) { MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]"; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index 259fb08c98..7c18bd9c9e 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -52,6 +52,7 @@ class AscendKernelRuntime : public KernelRuntime { void ClearGlobalIdleMem() override; bool SyncStream() override; void SetContext() override; + void CreateContext() override; protected: DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc index 817fbd02c4..8b9267f820 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc @@ -40,7 +40,6 @@ void AscendMemoryManager::MallocDeviceMemory() { MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]"; } } - AscendMemoryPool::GetInstance().Init(device_mem_base_, device_mem_size_, dynamic_mem_offset_); } diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index f941abf7fc..91387d3a68 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -75,6 +75,7 @@ class KernelRuntime { const std::vector &execution_order); virtual bool SyncStream() = 0; virtual void ClearGlobalIdleMem() {} + virtual void CreateContext() {} virtual void SetContext() {} uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) { return mem_manager_->MallocMem(type, size, address);