gpu kernel runtime code review

pull/2919/head
limingqi107 5 years ago
parent 7b65c5483b
commit 44959f8874

File diff suppressed because it is too large Load Diff

@ -53,9 +53,9 @@ class GPUKernelRuntime : public KernelRuntime {
// The related functions and members for using dynamic memory pool. // The related functions and members for using dynamic memory pool.
void InitKernelRefCount(const session::KernelGraph *graph); void InitKernelRefCount(const session::KernelGraph *graph);
void InitKernelOutputAddress(const session::KernelGraph *graph); void InitKernelOutputAddress(const session::KernelGraph *graph);
void InitMemorySwapInfo(const session::KernelGraph *graph);
void ClearKernelOutputAddress(const session::KernelGraph *graph); void ClearKernelOutputAddress(const session::KernelGraph *graph);
bool LaunchKernelDynamic(const session::KernelGraph *graph); bool LaunchKernelDynamic(const session::KernelGraph *graph);
bool AddMemSwapTask(const AnfNodePtr &kernel);
bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size); bool AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size);
void *AttemptMallocMem(size_t size); void *AttemptMallocMem(size_t size);
bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel, bool AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
@ -74,8 +74,14 @@ class GPUKernelRuntime : public KernelRuntime {
std::vector<size_t> size_list); std::vector<size_t> size_list);
void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces,
uint32_t graph_id); uint32_t graph_id);
bool AddMemorySwapTask(const AnfNodePtr &kernel);
bool UpdateMemorySwapInfo(const session::KernelGraph *graph);
bool UpdateMemorySwapTask(const AnfNodePtr &kernel);
void UpdateHostSwapQueue(const DeviceAddressPtr device_address);
void UpdateDeviceSwapQueue();
void ClearSwapQueue();
std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_; std::unordered_map<uint32_t, MemReuseUtilPtr> mem_reuse_util_map_;
std::unordered_map<void *, MemSwapManagerPtr> mem_swap_map_; std::unordered_map<uint32_t, MemSwapManagerPtr> mem_swap_map_;
MemSwapManagerPtr mem_swap_manager_{nullptr}; MemSwapManagerPtr mem_swap_manager_{nullptr};
}; };
MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime);

@ -187,8 +187,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
GetSummaryNodes(graph.get()); GetSummaryNodes(graph.get());
// Remove NoOp from execution graph // Remove NoOp from execution graph
opt::RemoveNopNode(graph.get()); opt::RemoveNopNode(graph.get());
// Alloc memory, including static memory and dynamic memory // Set graph manager.
AllocateMemory(graph.get());
MS_EXCEPTION_IF_NULL(context_); MS_EXCEPTION_IF_NULL(context_);
FuncGraphManagerPtr manager = MakeManager({graph}); FuncGraphManagerPtr manager = MakeManager({graph});
context_->AddManager(manager); context_->AddManager(manager);
@ -196,6 +195,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
manager->AddFuncGraph(graph); manager->AddFuncGraph(graph);
graph->set_manager(manager); graph->set_manager(manager);
} }
// Alloc memory, including static memory and dynamic memory
AllocateMemory(graph.get());
return graph_id; return graph_id;
} }

Loading…
Cancel
Save