|
|
@ -43,6 +43,11 @@ struct GPUPlace {
|
|
|
|
int device;
|
|
|
|
int device;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct CudnnPlace : public GPUPlace {
|
|
|
|
|
|
|
|
CudnnPlace() : GPUPlace() {}
|
|
|
|
|
|
|
|
explicit CudnnPlace(int d) : GPUPlace(d) {}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
struct IsGPUPlace : public boost::static_visitor<bool> {
|
|
|
|
struct IsGPUPlace : public boost::static_visitor<bool> {
|
|
|
|
bool operator()(const CPUPlace &) const { return false; }
|
|
|
|
bool operator()(const CPUPlace &) const { return false; }
|
|
|
|
bool operator()(const GPUPlace &gpu) const { return true; }
|
|
|
|
bool operator()(const GPUPlace &gpu) const { return true; }
|
|
|
@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
|
|
|
|
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
|
|
|
|
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
|
|
|
|
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
|
|
|
|
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
|
|
|
|
|
|
|
|
|
|
|
|
typedef boost::variant<GPUPlace, CPUPlace> Place;
|
|
|
|
typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace> Place;
|
|
|
|
|
|
|
|
|
|
|
|
// static check number of place types is less equal than
|
|
|
|
// static check number of place types is less equal than
|
|
|
|
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
|
|
|
|
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
|
|
|
|