|
|
@ -31,7 +31,7 @@ class DeviceContext {
|
|
|
|
virtual Place GetPlace() const = 0;
|
|
|
|
virtual Place GetPlace() const = 0;
|
|
|
|
|
|
|
|
|
|
|
|
template <typename DeviceType>
|
|
|
|
template <typename DeviceType>
|
|
|
|
inline DeviceType get_eigen_device();
|
|
|
|
DeviceType get_eigen_device();
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class CPUDeviceContext : public DeviceContext {
|
|
|
|
class CPUDeviceContext : public DeviceContext {
|
|
|
@ -52,11 +52,6 @@ class CPUDeviceContext : public DeviceContext {
|
|
|
|
Eigen::DefaultDevice* eigen_device_{nullptr};
|
|
|
|
Eigen::DefaultDevice* eigen_device_{nullptr};
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
|
|
|
|
|
|
|
|
return dynamic_cast<CPUDeviceContext*>(this)->eigen_device();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
|
|
|
|
|
|
|
class GPUPlaceGuard {
|
|
|
|
class GPUPlaceGuard {
|
|
|
@ -183,10 +178,6 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
curandGenerator_t rand_generator_{nullptr};
|
|
|
|
curandGenerator_t rand_generator_{nullptr};
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
|
|
|
|
|
|
|
|
return dynamic_cast<CUDADeviceContext*>(this)->eigen_device();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace platform
|
|
|
|
} // namespace platform
|
|
|
|