|
|
|
@ -20,19 +20,13 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/platform/dynload/cudnn.h"
|
|
|
|
|
#include "paddle/platform/dynload/curand.h"
|
|
|
|
|
#define EIGEN_USE_GPU
|
|
|
|
|
#include "paddle/platform/device_context.h"
|
|
|
|
|
#include "paddle/platform/place.h"
|
|
|
|
|
#include "unsupported/Eigen/CXX11/Tensor"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace platform {
|
|
|
|
|
|
|
|
|
|
class CUDADeviceContext;
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
|
|
|
|
|
return static_cast<CUDADeviceContext*>(this)->eigen_handle();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class GPUPlaceGuard {
|
|
|
|
|
public:
|
|
|
|
|
explicit GPUPlaceGuard(GPUPlace new_place) : previous_(GetCurrentDeviceId()) {
|
|
|
|
@ -49,7 +43,7 @@ class GPUPlaceGuard {
|
|
|
|
|
|
|
|
|
|
class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
explicit Device(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
|
|
|
|
|
explicit CUDADeviceContext(const GPUPlace gpu_place) : gpu_place_(gpu_place) {
|
|
|
|
|
GPUPlaceGuard guard(gpu_place_);
|
|
|
|
|
paddle::platform::throw_on_error(cudaStreamCreate(&stream_),
|
|
|
|
|
"cudaStreamCreate failed");
|
|
|
|
@ -156,5 +150,10 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
int random_seed_;
|
|
|
|
|
curandGenerator_t rand_generator_{nullptr};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
|
|
|
|
|
return dynamic_cast<CUDADeviceContext*>(this)->eigen_device();
|
|
|
|
|
}
|
|
|
|
|
} // namespace platform
|
|
|
|
|
} // namespace paddle
|
|
|
|
|