|
|
|
@ -31,6 +31,14 @@ struct CPUPlace {
|
|
|
|
|
inline bool operator!=(const CPUPlace &) const { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct MKLDNNPlace : public CPUPlace {
|
|
|
|
|
MKLDNNPlace() {}
|
|
|
|
|
|
|
|
|
|
// needed for variant equality comparison
|
|
|
|
|
inline bool operator==(const MKLDNNPlace &) const { return true; }
|
|
|
|
|
inline bool operator!=(const MKLDNNPlace &) const { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct GPUPlace {
|
|
|
|
|
GPUPlace() : GPUPlace(0) {}
|
|
|
|
|
explicit GPUPlace(int d) : device(d) {}
|
|
|
|
@ -45,14 +53,21 @@ struct GPUPlace {
|
|
|
|
|
|
|
|
|
|
struct IsGPUPlace : public boost::static_visitor<bool> {
|
|
|
|
|
bool operator()(const CPUPlace &) const { return false; }
|
|
|
|
|
bool operator()(const MKLDNNPlace &) const { return false; }
|
|
|
|
|
bool operator()(const GPUPlace &gpu) const { return true; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct IsMKLDNNPlace : public boost::static_visitor<bool> {
|
|
|
|
|
bool operator()(const MKLDNNPlace &) const { return true; }
|
|
|
|
|
bool operator()(const CPUPlace &) const { return false; }
|
|
|
|
|
bool operator()(const GPUPlace &) const { return false; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Define the max number of Place in bit length. i.e., the max number of places
|
|
|
|
|
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
|
|
|
|
|
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
|
|
|
|
|
|
|
|
|
|
typedef boost::variant<GPUPlace, CPUPlace> Place;
|
|
|
|
|
typedef boost::variant<GPUPlace, CPUPlace, MKLDNNPlace> Place;
|
|
|
|
|
|
|
|
|
|
// static check number of place types is less equal than
|
|
|
|
|
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
|
|
|
|
@ -65,9 +80,11 @@ const Place &get_place();
|
|
|
|
|
|
|
|
|
|
const GPUPlace default_gpu();
|
|
|
|
|
const CPUPlace default_cpu();
|
|
|
|
|
const MKLDNNPlace default_mkldnn();
|
|
|
|
|
|
|
|
|
|
bool is_gpu_place(const Place &);
|
|
|
|
|
bool is_cpu_place(const Place &);
|
|
|
|
|
bool is_mkldnn_place(const Place &);
|
|
|
|
|
bool places_are_same_class(const Place &, const Place &);
|
|
|
|
|
|
|
|
|
|
std::ostream &operator<<(std::ostream &, const Place &);
|
|
|
|
|