You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
199 lines
6.2 KiB
199 lines
6.2 KiB
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#include "paddle/platform/device_context.h"
|
|
#include "paddle/memory/memory.h"
|
|
|
|
namespace paddle {
|
|
namespace platform {
|
|
|
|
// Definition of the static DeviceContextPool::pool pointer; it starts out
// null. NOTE(review): whatever assigns it lives outside this file section.
DeviceContextPool* DeviceContextPool::pool{nullptr};
|
|
|
|
// Looks up the device context registered for a single place.
//
// Returns the pointer stored in device_contexts_ under `place`. When no entry
// exists for `place`, the failure is reported through PADDLE_THROW.
const platform::DeviceContext* DeviceContextPool::Borrow(
    const platform::Place& place) {
  const auto found = device_contexts_.find(place);
  const bool missing = (found == device_contexts_.end());
  if (missing) {
    PADDLE_THROW(
        "'Place' is not supported, Please re-compile with WITH_GPU "
        "option");
  }
  return found->second;
}
|
|
|
|
std::vector<const platform::DeviceContext*> DeviceContextPool::Borrow(
|
|
const std::vector<platform::Place>& places) {
|
|
PADDLE_ENFORCE_GT(places.size(), 0);
|
|
PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
|
|
std::vector<const platform::DeviceContext*> borrowed_contexts;
|
|
for (auto& place : places) {
|
|
auto it = device_contexts_.find(place);
|
|
if (it != device_contexts_.end()) {
|
|
borrowed_contexts.emplace_back(it->second);
|
|
} else {
|
|
PADDLE_THROW(
|
|
"'Place' is not supported, Please re-compile with WITH_GPU "
|
|
"option");
|
|
}
|
|
}
|
|
return borrowed_contexts;
|
|
}
|
|
|
|
// Builds one device context per distinct entry of `places`.
//
// CPU places get a CPUDeviceContext. GPU places get a CUDADeviceContext when
// compiled with PADDLE_WITH_CUDA; without it PADDLE_THROW is raised. A place
// that is neither CPU nor GPU is skipped without an error.
//
// `places` must be non-empty (checked with PADDLE_ENFORCE_GT).
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(places.size(), 0);
  for (const auto& place : places) {
    // A repeated place must not allocate a second context. Assuming
    // device_contexts_ is a unique-key map, emplace() would keep the existing
    // entry and the freshly new'd context (and, for CUDA, its stream and
    // library handles) would never be released.
    if (device_contexts_.find(place) != device_contexts_.end()) {
      continue;
    }

    if (platform::is_cpu_place(place)) {
      device_contexts_.emplace(
          place, new platform::CPUDeviceContext(
                     boost::get<platform::CPUPlace>(place)));
    } else if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
      device_contexts_.emplace(
          place, new platform::CUDADeviceContext(
                     boost::get<platform::CUDAPlace>(place)));
#else
      PADDLE_THROW(
          "'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
          "option");
#endif
    }
  }
}
|
|
|
|
// Default constructor: place_ is default-initialized and a fresh
// Eigen::DefaultDevice is installed in eigen_device_.
CPUDeviceContext::CPUDeviceContext() {
  auto* default_device = new Eigen::DefaultDevice();
  eigen_device_.reset(default_device);
}
|
|
|
|
// Constructs a CPU context for `place`: stores it in place_ and installs a
// fresh Eigen::DefaultDevice in eigen_device_.
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
  auto* default_device = new Eigen::DefaultDevice();
  eigen_device_.reset(default_device);
}
|
|
|
|
// Returns the raw pointer held by eigen_device_; the context keeps ownership.
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* const device = eigen_device_.get();
  return device;
}
|
|
|
|
// Returns the stored place_ as a Place.
Place CPUDeviceContext::GetPlace() const {
  return place_;
}
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
class EigenCudaStreamDevice : public Eigen::StreamInterface {
|
|
public:
|
|
EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
|
|
Eigen::initializeDeviceProp();
|
|
}
|
|
~EigenCudaStreamDevice() override {}
|
|
|
|
void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
|
|
stream_ = cuda_stream;
|
|
place_ = place;
|
|
device_prop_ = &Eigen::m_deviceProperties[place.device];
|
|
}
|
|
|
|
const cudaStream_t& stream() const override { return *stream_; }
|
|
|
|
const cudaDeviceProp& deviceProperties() const override {
|
|
return *device_prop_;
|
|
}
|
|
|
|
void* allocate(size_t num_bytes) const override {
|
|
return paddle::memory::Alloc(place_, num_bytes);
|
|
}
|
|
|
|
void deallocate(void* buffer) const override {
|
|
paddle::memory::Free(place_, buffer);
|
|
}
|
|
|
|
void* scratchpad() const override {
|
|
if (scratch_ == NULL) {
|
|
scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
|
|
}
|
|
return scratch_;
|
|
}
|
|
|
|
unsigned int* semaphore() const override {
|
|
if (semaphore_ == NULL) {
|
|
char* scratch =
|
|
static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
|
|
semaphore_ = reinterpret_cast<unsigned int*>(scratch);
|
|
PADDLE_ENFORCE(
|
|
cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
|
|
}
|
|
return semaphore_;
|
|
}
|
|
|
|
private:
|
|
CUDAPlace place_;
|
|
const cudaStream_t* stream_; // not owned;
|
|
const cudaDeviceProp* device_prop_; // not owned;
|
|
mutable void* scratch_;
|
|
mutable unsigned int* semaphore_;
|
|
};
|
|
|
|
// Constructs a CUDA context for `place`.
//
// In order: calls SetDeviceId with the place's device, creates stream_, wraps
// it in an EigenCudaStreamDevice used to build the Eigen::GpuDevice, then
// creates the cuBLAS and cuDNN handles and attaches each to stream_.
// Every CUDA / library call is checked with PADDLE_ENFORCE.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  SetDeviceId(place_.device);
  PADDLE_ENFORCE(cudaStreamCreate(&stream_));

  // Eigen side: the stream wrapper must exist before the GpuDevice that
  // receives a pointer to it.
  eigen_stream_.reset(new EigenCudaStreamDevice());
  eigen_stream_->Reinitialize(&stream_, place);
  eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));

  // cuBLAS handle bound to stream_.
  PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
  PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));

  // cuDNN handle bound to stream_.
  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
|
|
|
|
// Tears the context down: calls SetDeviceId for this context's device, waits
// for stream_ via Wait(), destroys the cuBLAS and cuDNN handles, releases the
// Eigen stream wrapper and device, and finally destroys stream_.
// Every CUDA / library call is checked with PADDLE_ENFORCE.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
  Wait();

  // Library handles go first, while stream_ still exists.
  PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));

  eigen_stream_.reset();
  eigen_device_.reset();

  PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
|
|
|
|
// Returns the stored place_ as a Place.
Place CUDADeviceContext::GetPlace() const {
  return place_;
}
|
|
|
|
// Synchronizes stream_ with cudaStreamSynchronize and then checks
// cudaGetLastError(); either failure is reported through PADDLE_ENFORCE.
void CUDADeviceContext::Wait() const {
  PADDLE_ENFORCE(cudaStreamSynchronize(stream_));

  PADDLE_ENFORCE(cudaGetLastError());
}
|
|
|
|
// Returns the raw pointer held by eigen_device_; the context keeps ownership.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  Eigen::GpuDevice* const device = eigen_device_.get();
  return device;
}
|
|
|
|
// Returns the cuBLAS handle stored in cublas_handle_.
cublasHandle_t CUDADeviceContext::cublas_handle() const
{
  return cublas_handle_;
}
|
|
|
|
// Returns the cuDNN handle stored in cudnn_handle_.
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  return cudnn_handle_;
}
|
|
|
|
// Returns the CUDA stream stored in stream_.
cudaStream_t CUDADeviceContext::stream() const {
  return stream_;
}
|
|
|
|
// Constructs the base CUDADeviceContext for `place`, then creates a cuDNN
// handle in this class's cudnn_handle_ and attaches it to the stream returned
// by stream(). Both calls are checked with PADDLE_ENFORCE.
CUDNNDeviceContext::CUDNNDeviceContext(CUDAPlace place)
    : CUDADeviceContext(place) {
  PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));

  PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
}
|
|
|
|
// Calls SetDeviceId for this context's CUDA device, waits via Wait(), and
// destroys the cuDNN handle held in cudnn_handle_ (checked with
// PADDLE_ENFORCE).
CUDNNDeviceContext::~CUDNNDeviceContext() {
  const Place current_place = GetPlace();
  SetDeviceId(boost::get<CUDAPlace>(current_place).device);
  Wait();
  PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
|
|
|
|
// Returns the cuDNN handle stored in cudnn_handle_.
cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const {
  return cudnn_handle_;
}
|
|
|
|
#endif
|
|
|
|
} // namespace platform
|
|
} // namespace paddle
|