|
|
|
@ -31,16 +31,16 @@ class DeviceContext {
|
|
|
|
|
virtual Place GetPlace() const = 0;
|
|
|
|
|
|
|
|
|
|
template <typename DeviceType>
|
|
|
|
|
DeviceType get_eigen_device();
|
|
|
|
|
DeviceType* get_eigen_device();
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CPUDeviceContext : public DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
Eigen::DefaultDevice eigen_device() {
|
|
|
|
|
Eigen::DefaultDevice* eigen_device() {
|
|
|
|
|
if (!eigen_device_) {
|
|
|
|
|
eigen_device_ = new Eigen::DefaultDevice();
|
|
|
|
|
eigen_device_.reset(new Eigen::DefaultDevice());
|
|
|
|
|
}
|
|
|
|
|
return *eigen_device_;
|
|
|
|
|
return eigen_device_.get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Place GetPlace() const override {
|
|
|
|
@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
Eigen::DefaultDevice* eigen_device_{nullptr};
|
|
|
|
|
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
@ -74,8 +74,8 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
GPUPlaceGuard guard(gpu_place_);
|
|
|
|
|
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
|
|
|
|
|
"cudaStreamCreate failed");
|
|
|
|
|
eigen_stream_ = new Eigen::CudaStreamDevice(&stream_);
|
|
|
|
|
eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
|
|
|
|
|
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
|
|
|
|
|
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Place GetPlace() const override {
|
|
|
|
@ -90,7 +90,7 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
|
|
|
|
|
cudaStream_t stream() { return stream_; }
|
|
|
|
|
|
|
|
|
|
Eigen::GpuDevice eigen_device() { return *eigen_device_; }
|
|
|
|
|
Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); }
|
|
|
|
|
|
|
|
|
|
cublasHandle_t cublas_handle() {
|
|
|
|
|
if (!blas_handle_) {
|
|
|
|
@ -155,10 +155,8 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
rand_generator_) == CURAND_STATUS_SUCCESS,
|
|
|
|
|
"curandDestroyGenerator failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
delete eigen_stream_;
|
|
|
|
|
delete eigen_device_;
|
|
|
|
|
|
|
|
|
|
eigen_stream_.reset();
|
|
|
|
|
eigen_device_.reset();
|
|
|
|
|
paddle::platform::throw_on_error(cudaStreamDestroy(stream_),
|
|
|
|
|
"cudaStreamDestroy failed");
|
|
|
|
|
}
|
|
|
|
@ -167,8 +165,8 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
GPUPlace gpu_place_;
|
|
|
|
|
cudaStream_t stream_;
|
|
|
|
|
|
|
|
|
|
Eigen::CudaStreamDevice* eigen_stream_;
|
|
|
|
|
Eigen::GpuDevice* eigen_device_;
|
|
|
|
|
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_;
|
|
|
|
|
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
|
|
|
|
|
|
|
|
|
|
cublasHandle_t blas_handle_{nullptr};
|
|
|
|
|
|
|
|
|
|