|
|
|
@ -31,7 +31,7 @@ namespace platform {
|
|
|
|
|
class DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
virtual ~DeviceContext() {}
|
|
|
|
|
virtual Place place() const = 0;
|
|
|
|
|
virtual Place GetPlace() const = 0;
|
|
|
|
|
|
|
|
|
|
template <typename DeviceType>
|
|
|
|
|
DeviceType* get_eigen_device() const;
|
|
|
|
@ -45,7 +45,7 @@ class CPUDeviceContext : public DeviceContext {
|
|
|
|
|
|
|
|
|
|
Eigen::DefaultDevice* eigen_device() const;
|
|
|
|
|
|
|
|
|
|
Place place() const override;
|
|
|
|
|
Place GetPlace() const override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
|
|
|
|
@ -59,13 +59,13 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
virtual ~CUDADeviceContext();
|
|
|
|
|
|
|
|
|
|
/*! \brief Wait for all operations completion in the stream. */
|
|
|
|
|
void wait() const;
|
|
|
|
|
void Wait() const;
|
|
|
|
|
|
|
|
|
|
/*! \brief Return CUDA stream in the device context. */
|
|
|
|
|
cudaStream_t stream() const;
|
|
|
|
|
|
|
|
|
|
/*! \brief Return place in the device context. */
|
|
|
|
|
Place place() const override;
|
|
|
|
|
Place GetPlace() const override;
|
|
|
|
|
|
|
|
|
|
/*! \brief Return eigen device in the device context. */
|
|
|
|
|
Eigen::GpuDevice* eigen_device() const;
|
|
|
|
|