|
|
|
@ -34,13 +34,14 @@ class DeviceContext {
|
|
|
|
|
|
|
|
|
|
template <typename DeviceType>
|
|
|
|
|
DeviceType* get_eigen_device() const;
|
|
|
|
|
|
|
|
|
|
virtual void Wait() const {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CPUDeviceContext : public DeviceContext {
|
|
|
|
|
public:
|
|
|
|
|
CPUDeviceContext();
|
|
|
|
|
explicit CPUDeviceContext(CPUPlace place);
|
|
|
|
|
virtual ~CPUDeviceContext() {}
|
|
|
|
|
|
|
|
|
|
Eigen::DefaultDevice* eigen_device() const;
|
|
|
|
|
|
|
|
|
@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext {
|
|
|
|
|
virtual ~CUDADeviceContext();
|
|
|
|
|
|
|
|
|
|
/*! \brief Wait for all operations completion in the stream. */
|
|
|
|
|
void Wait() const;
|
|
|
|
|
void Wait() const override;
|
|
|
|
|
|
|
|
|
|
/*! \brief Return place in the device context. */
|
|
|
|
|
Place GetPlace() const override;
|
|
|
|
|