!13344 add DeviceContext module

From: @zyli2020
Reviewed-by: 
Signed-off-by:
pull/13344/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5fd3d140b6

@ -208,6 +208,7 @@ set(SUB_COMP
backend/kernel_compiler
backend/session
runtime/device
runtime/hardware
runtime/hccl_adapter
frontend/optimizer
frontend/parallel

@ -196,7 +196,8 @@ void GPUSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph)
}
// Build the executive 'KernelMod' objects for every kernel in the graph's
// execution order. Note: the old GpuBuild(kernel_graph) entry point was
// renamed to CreateGPUKernel(kernels) in this change; the stale call is removed.
void GPUSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
  auto kernels = kernel_graph->execution_order();
  device::gpu::CreateGPUKernel(kernels);
}
void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const {

@ -31,6 +31,7 @@ namespace cpu {
class CPUSimpleMemPlan;
class CPUMemoryManager;
class CPUKernelRuntime;
class CPUDeviceContext;
} // namespace cpu
namespace ascend {
class AscendKernelRuntime;
@ -96,6 +97,7 @@ class DeviceAddress : public mindspore::DeviceSync {
friend class mindspore::device::cpu::CPUSimpleMemPlan;
friend class mindspore::device::cpu::CPUMemoryManager;
friend class mindspore::device::cpu::CPUKernelRuntime;
friend class mindspore::device::cpu::CPUDeviceContext;
friend class mindspore::device::gpu::GPUKernelRuntime;
friend class mindspore::device::gpu::GPUMemoryManager;
friend class mindspore::device::ascend::AscendKernelRuntime;

@ -27,12 +27,10 @@
namespace mindspore {
namespace device {
namespace gpu {
void GpuBuild(const KernelGraphPtr &kernel_graph) {
void CreateGPUKernel(const std::vector<CNodePtr> &kernels) {
kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance();
MS_EXCEPTION_IF_NULL(bin_map);
MS_EXCEPTION_IF_NULL(kernel_graph);
bool already_check_nvcc = false;
auto kernels = kernel_graph->execution_order();
std::vector<AnfNodePtr> akg_nodes;
for (const auto &kernel : kernels) {
std::string kernel_name = session::AnfRuntimeAlgorithm::GetCNodeName(kernel);

@ -13,16 +13,17 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_BUILD_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_BUILD_H_

#include <vector>
#include <memory>
#include "backend/session/kernel_graph.h"

namespace mindspore {
namespace device {
namespace gpu {
// Create the executive 'KernelMod' objects for the given kernel nodes.
// Replaces the former GpuBuild(kernel_graph) entry point.
void CreateGPUKernel(const std::vector<CNodePtr> &kernels);
}  // namespace gpu
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_GPU_GPU_KERNEL_BUILD_H_

@ -53,6 +53,7 @@ void GPULaunchkernel::KernelSelect(std::shared_ptr<session::KernelGraph> kernel_
}
// Build 'KernelMod' objects for the graph's kernels in execution order.
// The stale call to the removed GpuBuild(kernel_graph) is dropped; this
// patch renames that entry point to CreateGPUKernel(kernels).
void GPULaunchkernel::KernelBuild(std::shared_ptr<session::KernelGraph> kernel_graph) {
  auto kernels = kernel_graph->execution_order();
  device::gpu::CreateGPUKernel(kernels);
}
} // namespace mindspore::device::gpu

@ -0,0 +1,17 @@
# Sources shared by every backend: the DeviceContext manager/registry.
file(GLOB_RECURSE HARDWARE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
        "device_context_manager.cc")

# GPU backend implementation of DeviceContext.
if(ENABLE_GPU)
    file(GLOB_RECURSE HARDWARE_GPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
            "gpu/gpu_device_context.cc")
endif()

# CPU backend implementation of DeviceContext.
if(ENABLE_CPU)
    file(GLOB_RECURSE HARDWARE_CPU_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
            "cpu/cpu_device_context.cc")
endif()

# Attribute all hardware sources to the SM_DEVICE logging submodule.
# NOTE(review): HARDWARE_D_SRC_LIST is referenced but never set in this file —
# presumably reserved for the Ascend ('D') backend; it expands to empty here.
set_property(SOURCE ${HARDWARE_SRC_LIST} ${HARDWARE_D_SRC_LIST} ${HARDWARE_GPU_SRC_LIST} ${HARDWARE_CPU_SRC_LIST}
        PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_DEVICE)
add_library(_mindspore_runtime_hardware_obj OBJECT ${HARDWARE_SRC_LIST} ${HARDWARE_D_SRC_LIST}
        ${HARDWARE_GPU_SRC_LIST} ${HARDWARE_CPU_SRC_LIST})

@ -0,0 +1,78 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/hardware/cpu/cpu_device_context.h"
#include <string>
#include "runtime/device/cpu/cpu_device_address.h"
#include "runtime/device/cpu/cpu_memory_manager.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#include "runtime/device/cpu/kernel_select_cpu.h"
#include "utils/trace_base.h"
namespace mindspore {
namespace device {
namespace cpu {
// Create the CPU memory manager on first use. Idempotent: every call after
// the first succeeds immediately without re-creating any resource.
bool CPUDeviceContext::Initialize() {
  if (!initialized_) {
    mem_manager_ = std::make_shared<CPUMemoryManager>();
    MS_EXCEPTION_IF_NULL(mem_manager_);
    initialized_ = true;
  }
  return true;
}
bool CPUDeviceContext::AllocateMemory(const DeviceAddressPtr &address, size_t size) const {
address->ptr_ = static_cast<CPUMemoryManager *>(mem_manager_.get())->StaticMemMalloc(size);
return true;
}
void CPUDeviceContext::FreeMemory(const DeviceAddressPtr &address) const {
static_cast<CPUMemoryManager *>(mem_manager_.get())->MemFree(address->ptr_);
address->ptr_ = nullptr;
}
void CPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
for (const auto &node : nodes) {
SetKernelInfo(node);
}
}
void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
for (const auto &node : nodes) {
MS_EXCEPTION_IF_NULL(node);
std::string kernel_name = AnfAlgo::GetCNodeName(node);
std::shared_ptr<kernel::CPUKernel> cpu_kernel = kernel::CPUKernelFactory::GetInstance().Create(kernel_name, node);
if (!cpu_kernel) {
MS_LOG(EXCEPTION) << "Build cpu operator[" << node->fullname_with_scope() << "] failed";
}
cpu_kernel->Init(node);
AnfAlgo::SetKernelMod(cpu_kernel, node.get());
}
}
// Launch a kernel synchronously via its 'KernelMod' and return the kernel's
// own success flag. CPU needs no stream handle, so nullptr is passed.
bool CPUDeviceContext::LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
                                    const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs) const {
  MS_EXCEPTION_IF_NULL(kernel_mod);
  return kernel_mod->Launch(inputs, workspace, outputs, nullptr);
}
MS_REGISTER_DEVICE(kCPUDevice, CPUDeviceContext);
} // namespace cpu
} // namespace device
} // namespace mindspore

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_CPU_DEVICE_CONTEXT_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_CPU_DEVICE_CONTEXT_H_
#include <vector>
#include <memory>
#include "runtime/hardware/device_context.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/device/memory_manager.h"
namespace mindspore {
namespace device {
namespace cpu {
// CPU implementation of the unified DeviceContext device-interaction interface.
class CPUDeviceContext : public DeviceContext {
 public:
  explicit CPUDeviceContext(const DeviceContextKey &device_context_key)
      : DeviceContext(device_context_key), mem_manager_(nullptr), initialized_(false) {}
  ~CPUDeviceContext() override = default;

  // Create the CPU memory manager; idempotent, always returns true.
  bool Initialize() override;

  // Allocate/free host memory through the CPU memory manager.
  bool AllocateMemory(const DeviceAddressPtr &address, size_t size) const override;
  void FreeMemory(const DeviceAddressPtr &address) const override;

  // Select and set kernel info (device data type/format) for each node.
  void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;
  // Create CPUKernel ('KernelMod') objects and attach them to the nodes.
  void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
  // Launch a kernel synchronously; CPU needs no stream (nullptr is used).
  bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
                    const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override;

 private:
  DISABLE_COPY_AND_ASSIGN(CPUDeviceContext);
  std::shared_ptr<MemoryManager> mem_manager_;
  // Set by Initialize(); guards against re-creating mem_manager_.
  bool initialized_;
};
} // namespace cpu
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_CPU_CPU_DEVICE_CONTEXT_H_

@ -0,0 +1,94 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_
#include <string>
#include <vector>
#include <memory>
#include "runtime/device/device_address.h"
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace device {
using mindspore::kernel::AddressPtr;
using mindspore::kernel::KernelMod;
struct DeviceContextKey {
  // Device type name, such as 'GPU', 'Ascend' or 'CPU'.
  std::string device_name_;
  uint32_t device_id_{0};

  // The result of ToString() ("<name>_<id>") serves as the key in the cache
  // map which maintains created DeviceContext objects.
  std::string ToString() const {
    std::string key = device_name_;
    key += "_";
    key += std::to_string(device_id_);
    return key;
  }
};
// DeviceContext is the unified interface of interaction with a device
// (memory management, kernel selection/creation/launching, stream sync).
class DeviceContext {
 public:
  explicit DeviceContext(const DeviceContextKey &device_context_key) : device_context_key_(device_context_key) {}
  virtual ~DeviceContext() = default;

  // Initialize the device context and return success or not.
  virtual bool Initialize() = 0;

  // Destroy device context and release device resource.
  virtual void Destroy() {}

  // Relevant functions to allocate and free device memory.
  virtual bool AllocateMemory(const DeviceAddressPtr &address, size_t size) const = 0;
  virtual void FreeMemory(const DeviceAddressPtr &address) const = 0;

  // Allocate continuous device memory end to end into 'addr_list'.
  // Communication operators may need continuous memory for input and output
  // to optimize the communication performance.
  virtual bool AllocateContinuousMemory(const DeviceAddressPtrList &addr_list, size_t total_size,
                                        const std::vector<size_t> &size_list) const {
    return true;
  }

  // Optimize the kernel graph according to different devices.
  virtual void OptimizeGraph(const KernelGraphPtr &graph) const {}

  // Select the matching backend kernels according to the data type and format of input and output for all
  // execution operators, and set final device data type and format information for backend kernels; the device
  // data type and format which replace the original ones will be used for executing kernels.
  virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {}

  // Generate 'KernelMod' for all kernels and set 'KernelMod' into each kernel;
  // 'KernelMod' is the real executive object of a kernel.
  virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}

  // Launch a kernel via the 'KernelMod' of the kernel.
  virtual bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
                            const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const = 0;

  // Synchronize a stream: devices such as GPU and Ascend use streams to launch kernels asynchronously;
  // 'SyncStream' blocks the thread until all tasks in the stream complete.
  // Devices that do not need streams may keep this default implementation.
  virtual bool SyncStream(size_t stream_id = 0) { return true; }

 protected:
  DeviceContextKey device_context_key_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/hardware/device_context_manager.h"
namespace mindspore {
namespace device {
void DeviceContextManager::Register(const std::string &device_name, DeviceContextCreator &&device_context_creator) {
if (device_context_creators_.find(device_name) == device_context_creators_.end()) {
(void)device_context_creators_.emplace(device_name, device_context_creator);
}
}
// Destroy every cached device context, then drop ownership of all of them.
void DeviceContextManager::ClearDeviceContexts() {
  std::lock_guard<std::mutex> guard(lock_);
  for (const auto &entry : device_contexts_) {
    MS_LOG(INFO) << "Release device " << entry.first;
    MS_EXCEPTION_IF_NULL(entry.second);
    entry.second->Destroy();
  }
  device_contexts_.clear();
}
// Return the cached DeviceContext for the key, creating it through the
// registered creator on first use. Throws if no creator is registered.
DeviceContext *DeviceContextManager::GetDeviceContext(const DeviceContextKey &device_context_key) {
  const std::string cache_key = device_context_key.ToString();
  std::lock_guard<std::mutex> guard(lock_);

  auto cached = device_contexts_.find(cache_key);
  if (cached != device_contexts_.end()) {
    return cached->second.get();
  }

  auto creator_iter = device_context_creators_.find(device_context_key.device_name_);
  if (creator_iter == device_context_creators_.end()) {
    MS_LOG(EXCEPTION) << "There is no device context creator for " << device_context_key.device_name_
                      << " with device id " << device_context_key.device_id_;
  }
  std::shared_ptr<DeviceContext> device_context = (creator_iter->second)(device_context_key);
  MS_EXCEPTION_IF_NULL(device_context);
  device_contexts_[cache_key] = device_context;
  return device_context.get();
}
} // namespace device
} // namespace mindspore

@ -0,0 +1,69 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_
#include <map>
#include <string>
#include <memory>
#include <utility>
#include <functional>
#include <mutex>
#include <vector>
#include "runtime/hardware/device_context.h"
namespace mindspore {
namespace device {
using DeviceContextCreator = std::function<std::shared_ptr<DeviceContext>(const DeviceContextKey &)>;
// Process-wide registry and cache of DeviceContext objects, keyed by the
// string form of DeviceContextKey.
class DeviceContextManager {
 public:
  // Singleton accessor.
  static DeviceContextManager &GetInstance() {
    static DeviceContextManager instance;
    return instance;
  }
  // Register a creator for a device type name; the first registration wins.
  void Register(const std::string &device_name, DeviceContextCreator &&device_context_creator);
  // Return the cached context for the key, creating it on first use.
  DeviceContext *GetDeviceContext(const DeviceContextKey &device_info);
  // Call Destroy() on every cached context and clear the cache.
  void ClearDeviceContexts();

 private:
  DeviceContextManager() = default;
  ~DeviceContextManager() = default;
  DISABLE_COPY_AND_ASSIGN(DeviceContextManager);

  // The string converted from DeviceContextKey -> DeviceContextPtr.
  std::map<std::string, DeviceContextPtr> device_contexts_;
  // The name of device -> DeviceContextCreator.
  std::map<std::string, DeviceContextCreator> device_context_creators_;
  // Guards both maps above.
  std::mutex lock_;
};
// Helper whose construction registers a DeviceContext creator with the
// manager; instantiated as a file-scope static by MS_REGISTER_DEVICE so
// registration happens during static initialization.
class DeviceContextRegister {
 public:
  DeviceContextRegister(const std::string &device_name, DeviceContextCreator &&runtime_creator) {
    DeviceContextManager::GetInstance().Register(device_name, std::move(runtime_creator));
  }
  ~DeviceContextRegister() = default;
};
// Define a file-scope static DeviceContextRegister so DEVICE_CONTEXT_CLASS is
// registered under DEVICE_NAME at program start-up.
#define MS_REGISTER_DEVICE(DEVICE_NAME, DEVICE_CONTEXT_CLASS) \
  static const DeviceContextRegister g_device_##DEVICE_NAME##_reg( \
    DEVICE_NAME, [](const DeviceContextKey &device_context_key) { \
      return std::make_shared<DEVICE_CONTEXT_CLASS>(device_context_key); \
    });
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_MANAGER_H_

@ -0,0 +1,154 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
 * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "runtime/hardware/gpu/gpu_device_context.h"
#include <dlfcn.h>
#include "runtime/device/gpu/kernel_info_setter.h"
#include "runtime/device/gpu/gpu_kernel_build.h"
#include "runtime/device/gpu/gpu_device_address.h"
#include "runtime/device/gpu/gpu_memory_manager.h"
#include "runtime/device/gpu/gpu_memory_allocator.h"
#include "runtime/device/gpu/gpu_stream_assign.h"
#include "runtime/device/gpu/distribution/collective_init.h"
#include "runtime/device/gpu/gpu_device_manager.h"
#include "runtime/device/gpu/gpu_buffer_mgr.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/gpu/gpu_common.h"
namespace mindspore {
namespace device {
namespace gpu {
// Set the device id and initialize device resources, the memory pool and
// (when distributed collectives were initialized) the NCCL communicator.
// On repeated calls only the max-device-memory limit is refreshed.
bool GPUDeviceContext::Initialize() {
  if (initialized_) {  // idiom: no '== true' comparison
    GPUMemoryAllocator::GetInstance().CheckMaxDeviceMemory();
    return true;
  }

  // Set device id and initialize device resource.
  if (!InitDevice()) {
    MS_LOG(ERROR) << "GPU InitDevice failed.";
    return false;
  }

  // Initialize memory pool.
  mem_manager_ = std::make_shared<GPUMemoryManager>();
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->MallocDeviceMemory();

  // Initialize NCCL when collective communication was initialized.
  // (Local variable renamed from 'collective_handle_' — the trailing
  // underscore is reserved for members.)
  const void *collective_handle = CollectiveInitializer::instance().collective_handle();
  bool collective_inited = CollectiveInitializer::instance().collective_inited();
  if (collective_inited && collective_handle != nullptr) {
    auto init_nccl_comm_funcptr =
      reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle), "InitNCCLComm"));
    MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
    (*init_nccl_comm_funcptr)();
  }

  initialized_ = true;
  return true;
}
// Bind this context's device id and initialize per-device resources.
// Returns false (after logging) on any failure.
bool GPUDeviceContext::InitDevice() {
  // Need at least one visible CUDA device.
  if (GPUDeviceManager::GetInstance().device_count() <= 0) {
    MS_LOG(ERROR) << "No GPU device found.";
    return false;
  }
  // Set the current device id unless one was already fixed elsewhere.
  if (!GPUDeviceManager::GetInstance().is_device_id_init()) {
    if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_context_key_.device_id_)) {
      MS_LOG(ERROR) << "Failed to set current device id: " << SizeToInt(device_context_key_.device_id_);
      return false;
    }
  }
  // Initialize device resource, such as stream, cudnn and cublas handle.
  GPUDeviceManager::GetInstance().InitDevice();
  auto stream = GPUDeviceManager::GetInstance().default_stream();
  if (stream == nullptr) {
    MS_LOG(ERROR) << "No default CUDA stream found.";
    return false;
  }
  // streams_[0] is the default stream used by LaunchKernel/SyncStream.
  streams_.push_back(stream);
  return true;
}
// Tear down GPU resources in order: data queue, device handles, device
// memory, and finally the AKG kernel cache.
// NOTE(review): initialized_ is not reset here, so a later Initialize() call
// returns early without re-creating resources — confirm this is intended.
void GPUDeviceContext::Destroy() {
  // Release GPU buffer manager resource.
  if (GpuBufferMgr::GetInstance().IsInit()) {
    if (!GpuBufferMgr::GetInstance().IsClosed() && !GpuBufferMgr::GetInstance().CloseNotify()) {
      MS_LOG(EXCEPTION) << "Could not close gpu data queue.";
    }
    CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue.");
  }

  // Release stream, cudnn and cublas handle, etc.
  GPUDeviceManager::GetInstance().ReleaseDevice();

  // Release device memory.
  if (mem_manager_ != nullptr) {
    mem_manager_->FreeDeviceMemory();
    mem_manager_ = nullptr;
  }

  // Clean GPU cached kernels generated by AKG, unless graphs are being saved
  // for debugging (MS_CTX_SAVE_GRAPHS_FLAG).
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!(context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG))) {
    kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance();
    MS_EXCEPTION_IF_NULL(bin_map);
    bin_map->RemoveKernelCache();
  }
}
bool GPUDeviceContext::AllocateMemory(const DeviceAddressPtr &address, size_t size) const {
return mem_manager_->MallocMemFromMemPool(address, size);
}
void GPUDeviceContext::FreeMemory(const DeviceAddressPtr &address) const { mem_manager_->FreeMemFromMemPool(address); }
bool GPUDeviceContext::AllocateContinuousMemory(const DeviceAddressPtrList &addr_list, size_t total_size,
const std::vector<size_t> &size_list) const {
return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
}
void GPUDeviceContext::SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {
for (const auto &node : nodes) {
SetKernelInfo(node);
}
}
// Build GPU 'KernelMod' executive objects for all given kernel nodes.
void GPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const { CreateGPUKernel(nodes); }
// Launch a kernel asynchronously on the default stream (streams_.front()).
// NOTE(review): assumes Initialize() has run so streams_ is non-empty —
// confirm callers cannot reach this before initialization.
bool GPUDeviceContext::LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
                                    const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs) const {
  MS_EXCEPTION_IF_NULL(kernel_mod);
  return kernel_mod->Launch(inputs, workspace, outputs, streams_.front());
}
// Block the calling thread until every task queued on streams_[stream_id]
// has completed. Throws when stream_id is out of range.
bool GPUDeviceContext::SyncStream(size_t stream_id) {
  if (stream_id >= streams_.size()) {
    MS_LOG(EXCEPTION) << "The stream_id: " << stream_id << " is greater than stream array size: " << streams_.size();
  }
  return GPUDeviceManager::GetInstance().SyncStream(streams_[stream_id]);
}
MS_REGISTER_DEVICE(kGPUDevice, GPUDeviceContext);
} // namespace gpu
} // namespace device
} // namespace mindspore

@ -0,0 +1,64 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_GPU_DEVICE_CONTEXT_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_GPU_DEVICE_CONTEXT_H_
#include <vector>
#include <memory>
#include "runtime/hardware/device_context.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/device/memory_manager.h"
namespace mindspore {
namespace device {
namespace gpu {
// GPU implementation of the unified DeviceContext device-interaction interface.
class GPUDeviceContext : public DeviceContext {
 public:
  explicit GPUDeviceContext(const DeviceContextKey &device_context_key)
      : DeviceContext(device_context_key), mem_manager_(nullptr), initialized_(false) {}
  ~GPUDeviceContext() override = default;

  // Set device id and initialize device resource, such as stream, cudnn and cublas handle.
  bool Initialize() override;

  // Release device memory, stream, cudnn and cublas handle, etc.
  void Destroy() override;

  // Allocate/free device memory through the GPU memory pool.
  bool AllocateMemory(const DeviceAddressPtr &address, size_t size) const override;
  void FreeMemory(const DeviceAddressPtr &address) const override;
  // Allocate end-to-end contiguous device memory for communication operators.
  bool AllocateContinuousMemory(const DeviceAddressPtrList &addr_list, size_t total_size,
                                const std::vector<size_t> &size_list) const override;

  // Select and set kernel info (device data type/format) for each node.
  void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const override;
  // Create GPU 'KernelMod' objects for each node.
  void CreateKernel(const std::vector<CNodePtr> &nodes) const override;
  // Launch a kernel asynchronously on the default stream.
  bool LaunchKernel(KernelMod *kernel_mod, const std::vector<AddressPtr> &inputs,
                    const std::vector<AddressPtr> &workspace, const std::vector<AddressPtr> &outputs) const override;
  // Block until all tasks on streams_[stream_id] have completed.
  bool SyncStream(size_t stream_id = 0) override;

 private:
  DISABLE_COPY_AND_ASSIGN(GPUDeviceContext);
  // Set the device id and initialize stream/handles; false on failure.
  bool InitDevice();

  std::shared_ptr<MemoryManager> mem_manager_;
  // streams_[0] is the default stream pushed by InitDevice().
  std::vector<void *> streams_;
  // Set by Initialize(); guards against full re-initialization.
  bool initialized_;
};
} // namespace gpu
} // namespace device
} // namespace mindspore
#endif // MINDSPORE_CCSRC_RUNTIME_HARDWARE_GPU_GPU_DEVICE_CONTEXT_H_
Loading…
Cancel
Save