!8696 fix context null error

From: @kisnwang
Reviewed-by: 
Signed-off-by:
pull/8696/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 270c156219

@ -125,6 +125,13 @@ void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
}
} // namespace
void AscendSession::Init(uint32_t device_id) {
InitExecutor(kAscendDevice, device_id);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
MS_EXCEPTION_IF_NULL(runtime_instance);
runtime_instance->CreateContext();
}
GraphId AscendSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
MS_LOG(INFO) << "Start";
// construct graph, if successfully, graph_sum_ + 1

@ -31,7 +31,6 @@
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/session_factory.h"
#include "backend/session/ascend_control_parser.h"
#include "runtime/context.h"
namespace mindspore {
namespace session {
@ -40,28 +39,8 @@ enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2,
class AscendSession : public SessionBasic {
public:
AscendSession() { final_graph_id_ = kInvalidGraphId; }
~AscendSession() {
if (rt_context_ != nullptr) {
auto ret = rtCtxDestroy(rt_context_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]";
}
rt_context_ = nullptr;
}
}
void Init(uint32_t device_id) override {
InitDevice(kAscendDevice, device_id);
auto ret = rtCtxCreate(&rt_context_, 0, device_id);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
}
ret = rtCtxSetCurrent(rt_context_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
}
}
~AscendSession() = default;
void Init(uint32_t device_id) override;
// get graph id of final graph
GraphId GetFinalRunGraph() const override { return final_graph_id_; }
@ -136,8 +115,6 @@ class AscendSession : public SessionBasic {
std::map<std::pair<GraphId, size_t>, tensor::TensorPtr> initial_tenosrs_;
// final_graph_id is used in every root graph has it's own session situation
GraphId final_graph_id_;
// ascend runtime context
rtContext_t rt_context_{nullptr};
};
MS_REG_SESSION(kAscendDevice, AscendSession);
} // namespace session

@ -29,7 +29,7 @@ class CPUSession : public SessionBasic {
public:
CPUSession() = default;
~CPUSession() override = default;
void Init(uint32_t device_id) override { InitDevice(kCPUDevice, device_id); }
void Init(uint32_t device_id) override { InitExecutor(kCPUDevice, device_id); }
protected:
void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *,

@ -89,7 +89,7 @@ void GPUSession::Init(uint32_t device_id) {
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id);
MS_LOG(INFO) << "Set device id " << device_id << " for gpu session.";
InitDevice(kGPUDevice, device_id);
InitExecutor(kGPUDevice, device_id);
}
void GPUSession::SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {

@ -337,7 +337,7 @@ bool IgnoreCreateParameterForMakeTuple(const AnfNodePtr &node) {
GraphId SessionBasic::graph_sum_ = 0;
void SessionBasic::InitDevice(const std::string &device_name, uint32_t device_id) {
void SessionBasic::InitExecutor(const std::string &device_name, uint32_t device_id) {
device_id_ = device_id;
context_ = std::make_shared<Context>(device_name, device_id);
executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);

@ -67,7 +67,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
virtual void Init(uint32_t device_id) { device_id_ = device_id; }
void InitDevice(const std::string &device_name, uint32_t device_id);
void InitExecutor(const std::string &device_name, uint32_t device_id);
virtual ~SessionBasic() { summary_callback_ = nullptr; }

@ -683,6 +683,16 @@ bool AscendKernelRuntime::SyncStream() {
return true;
}
void AscendKernelRuntime::CreateContext() {
if (rt_context_ == nullptr) {
auto ret = rtCtxCreate(&rt_context_, 0, device_id_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
}
}
InnerSetContext();
}
bool AscendKernelRuntime::InitDevice() {
int device_count = 0;
auto ret = rtGetDeviceCount(&device_count);
@ -713,12 +723,7 @@ bool AscendKernelRuntime::InitDevice() {
MS_LOG(ERROR) << "Call rtCtxGetCurrent failed, ret[" << ret << "]";
return false;
}
ret = rtCtxCreate(&rt_context_, 0, device_id_);
if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
}
InnerSetContext();
CreateContext();
ret = rtStreamCreate(&stream_, 0);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";

@ -52,6 +52,7 @@ class AscendKernelRuntime : public KernelRuntime {
void ClearGlobalIdleMem() override;
bool SyncStream() override;
void SetContext() override;
void CreateContext() override;
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,

@ -40,7 +40,6 @@ void AscendMemoryManager::MallocDeviceMemory() {
MS_EXCEPTION(DeviceProcessError) << "rtMalloc mem size[" << device_mem_size_ << "] fail, ret[" << ret << "]";
}
}
AscendMemoryPool::GetInstance().Init(device_mem_base_, device_mem_size_, dynamic_mem_offset_);
}

@ -75,6 +75,7 @@ class KernelRuntime {
const std::vector<CNodePtr> &execution_order);
virtual bool SyncStream() = 0;
virtual void ClearGlobalIdleMem() {}
virtual void CreateContext() {}
virtual void SetContext() {}
uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
return mem_manager_->MallocMem(type, size, address);

Loading…
Cancel
Save