|
|
|
@ -31,14 +31,6 @@ struct CPUPlace {
|
|
|
|
|
inline bool operator!=(const CPUPlace &) const { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct MKLDNNPlace {
|
|
|
|
|
MKLDNNPlace() {}
|
|
|
|
|
|
|
|
|
|
// needed for variant equality comparison
|
|
|
|
|
inline bool operator==(const MKLDNNPlace &) const { return true; }
|
|
|
|
|
inline bool operator!=(const MKLDNNPlace &) const { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct CUDAPlace {
|
|
|
|
|
CUDAPlace() : CUDAPlace(0) {}
|
|
|
|
|
explicit CUDAPlace(int d) : device(d) {}
|
|
|
|
@ -53,37 +45,21 @@ struct CUDAPlace {
|
|
|
|
|
int device;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct CUDNNPlace : public CUDAPlace {
|
|
|
|
|
CUDNNPlace() : CUDAPlace() {}
|
|
|
|
|
explicit CUDNNPlace(int d) : CUDAPlace(d) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct IsCUDAPlace : public boost::static_visitor<bool> {
|
|
|
|
|
bool operator()(const CPUPlace &) const { return false; }
|
|
|
|
|
bool operator()(const MKLDNNPlace &) const { return false; }
|
|
|
|
|
bool operator()(const CUDAPlace &gpu) const { return true; }
|
|
|
|
|
bool operator()(const CUDNNPlace &) const { return true; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct IsMKLDNNPlace : public boost::static_visitor<bool> {
|
|
|
|
|
bool operator()(const MKLDNNPlace &) const { return true; }
|
|
|
|
|
bool operator()(const CPUPlace &) const { return false; }
|
|
|
|
|
bool operator()(const CUDAPlace &) const { return false; }
|
|
|
|
|
bool operator()(const CUDNNPlace &) const { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
typedef boost::variant<CUDNNPlace, CUDAPlace, CPUPlace, MKLDNNPlace> Place;
|
|
|
|
|
typedef boost::variant<CUDAPlace, CPUPlace> Place;
|
|
|
|
|
|
|
|
|
|
void set_place(const Place &);
|
|
|
|
|
const Place &get_place();
|
|
|
|
|
|
|
|
|
|
const CUDAPlace default_gpu();
|
|
|
|
|
const CPUPlace default_cpu();
|
|
|
|
|
const MKLDNNPlace default_mkldnn();
|
|
|
|
|
|
|
|
|
|
bool is_gpu_place(const Place &);
|
|
|
|
|
bool is_cpu_place(const Place &);
|
|
|
|
|
bool is_mkldnn_place(const Place &);
|
|
|
|
|
bool places_are_same_class(const Place &, const Place &);
|
|
|
|
|
|
|
|
|
|
std::ostream &operator<<(std::ostream &, const Place &);
|
|
|
|
|