|
|
|
@ -31,7 +31,7 @@ namespace platform {
|
|
|
|
|
class DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
virtual ~DeviceContext() {}
|
|
|
|
|
virtual Place GetPlace() const = 0;
|
|
|
|
|
virtual Place place() const = 0;
|
|
|
|
|
|
|
|
|
|
template <typename DeviceType>
|
|
|
|
|
DeviceType* get_eigen_device() const;
|
|
|
|
@ -39,14 +39,13 @@ class DeviceContext {
|
|
|
|
|
|
|
|
|
|
class CPUDeviceContext : public DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); }
|
|
|
|
|
CPUDeviceContext();
|
|
|
|
|
CPUDeviceContext(CPUPlace);
|
|
|
|
|
virtual ~CPUDeviceContext() {}
|
|
|
|
|
|
|
|
|
|
Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); }
|
|
|
|
|
Eigen::DefaultDevice* eigen_device() const;
|
|
|
|
|
|
|
|
|
|
Place GetPlace() const override {
|
|
|
|
|
Place retv = CPUPlace();
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|
Place place() const override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
|
|
|
|
@ -54,119 +53,51 @@ class CPUDeviceContext : public DeviceContext {
|
|
|
|
|
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
|
|
|
|
|
class GPUPlaceGuard {
|
|
|
|
|
class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
|
|
|
|
|
if (previous_ != new_place) {
|
|
|
|
|
paddle::platform::SetDeviceId(new_place.device);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
explicit CUDADeviceContext(GPUPlace);
|
|
|
|
|
virtual ~CUDADeviceContext();
|
|
|
|
|
|
|
|
|
|
~GPUPlaceGuard() { paddle::platform::SetDeviceId(previous_.device); }
|
|
|
|
|
/*! \brief Wait for all operations completion in the stream. */
|
|
|
|
|
void wait() const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
GPUPlace previous_;
|
|
|
|
|
};
|
|
|
|
|
/*! \brief Return CUDA stream in the device context. */
|
|
|
|
|
cudaStream_t stream() const;
|
|
|
|
|
|
|
|
|
|
class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
|
|
|
|
|
GPUPlaceGuard guard(gpu_place_);
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamCreate(&stream_), "cudaStreamCreate failed");
|
|
|
|
|
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
|
|
|
|
|
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Place GetPlace() const override {
|
|
|
|
|
Place retv = GPUPlace();
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Wait() {
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
|
|
|
|
|
"cudaStreamSynchronize failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudaStream_t stream() { return stream_; }
|
|
|
|
|
|
|
|
|
|
Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); }
|
|
|
|
|
|
|
|
|
|
cublasHandle_t cublas_handle() {
|
|
|
|
|
if (!blas_handle_) {
|
|
|
|
|
GPUPlaceGuard guard(gpu_place_);
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::dynload::cublasCreate(&blas_handle_),
|
|
|
|
|
"cublasCreate failed");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
paddle::platform::dynload::cublasSetStream(blas_handle_, stream_),
|
|
|
|
|
"cublasSetStream failed");
|
|
|
|
|
}
|
|
|
|
|
return blas_handle_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudnnHandle_t cudnn_handle() {
|
|
|
|
|
if (!dnn_handle_) {
|
|
|
|
|
GPUPlaceGuard guard(gpu_place_);
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::dynload::cudnnCreate(&dnn_handle_),
|
|
|
|
|
"cudnnCreate failed");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
paddle::platform::dynload::cudnnSetStream(dnn_handle_, stream_),
|
|
|
|
|
"cudnnSetStream failed");
|
|
|
|
|
}
|
|
|
|
|
return dnn_handle_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
curandGenerator_t curand_generator() {
|
|
|
|
|
if (!rand_generator_) {
|
|
|
|
|
GPUPlaceGuard guard(gpu_place_);
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::dynload::curandCreateGenerator(
|
|
|
|
|
&rand_generator_, CURAND_RNG_PSEUDO_DEFAULT),
|
|
|
|
|
"curandCreateGenerator failed");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
paddle::platform::dynload::curandSetPseudoRandomGeneratorSeed(
|
|
|
|
|
rand_generator_, random_seed_),
|
|
|
|
|
"curandSetPseudoRandomGeneratorSeed failed");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
paddle::platform::dynload::curandSetStream(rand_generator_, stream_),
|
|
|
|
|
"curandSetStream failed");
|
|
|
|
|
}
|
|
|
|
|
return rand_generator_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~CUDADeviceContext() {
|
|
|
|
|
Wait();
|
|
|
|
|
if (blas_handle_) {
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::dynload::cublasDestroy(blas_handle_),
|
|
|
|
|
"cublasDestroy failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dnn_handle_) {
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::dynload::cudnnDestroy(dnn_handle_),
|
|
|
|
|
"cudnnDestroy failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (rand_generator_) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
paddle::platform::dynload::curandDestroyGenerator(rand_generator_),
|
|
|
|
|
"curandDestroyGenerator failed");
|
|
|
|
|
}
|
|
|
|
|
eigen_stream_.reset();
|
|
|
|
|
eigen_device_.reset();
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamDestroy(stream_), "cudaStreamDestroy failed");
|
|
|
|
|
}
|
|
|
|
|
/*! \brief Return place in the device context. */
|
|
|
|
|
Place place() const override;
|
|
|
|
|
|
|
|
|
|
/*! \brief Return eigen device in the device context. */
|
|
|
|
|
Eigen::GpuDevice* eigen_device() const;
|
|
|
|
|
|
|
|
|
|
// clang-format off
|
|
|
|
|
/*! \brief Return cublas handle in the device context. */
|
|
|
|
|
cublasHandle_t cublas_handle ();
|
|
|
|
|
|
|
|
|
|
/*! \brief Return cudnn handle in the device context. */
|
|
|
|
|
cudnnHandle_t cudnn_handle ();
|
|
|
|
|
|
|
|
|
|
/*! \brief Return curand handle in the device context. */
|
|
|
|
|
curandGenerator_t curand_generator();
|
|
|
|
|
// clang-format on
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
GPUPlace gpu_place_;
|
|
|
|
|
cudaStream_t stream_;
|
|
|
|
|
GPUPlace place_;
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
|
|
|
|
|
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
|
|
|
|
|
|
|
|
|
|
cublasHandle_t blas_handle_{nullptr};
|
|
|
|
|
private:
|
|
|
|
|
uint64_t seed_;
|
|
|
|
|
|
|
|
|
|
cudnnHandle_t dnn_handle_{nullptr};
|
|
|
|
|
cudaStream_t stream_;
|
|
|
|
|
|
|
|
|
|
int random_seed_;
|
|
|
|
|
curandGenerator_t rand_generator_{nullptr};
|
|
|
|
|
// clang-format off
|
|
|
|
|
cudnnHandle_t cudnn_handle_ = nullptr;
|
|
|
|
|
cublasHandle_t cublas_handle_ = nullptr;
|
|
|
|
|
curandGenerator_t curand_generator_ = nullptr;
|
|
|
|
|
// clang-format on
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|