|
|
|
@ -32,18 +32,18 @@ using namespace platform;
|
|
|
|
|
* 1111 -> FP64, GPUPlace, kNCHW, kMKLDNN
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
std::array<proto::DataType, 2> kDataType = {proto::DataType::FP32,
|
|
|
|
|
proto::DataType::FP64};
|
|
|
|
|
std::array<proto::DataType, 2> kDataType = {
|
|
|
|
|
{proto::DataType::FP32, proto::DataType::FP64}};
|
|
|
|
|
|
|
|
|
|
std::array<Place, 2> kPlace = {CPUPlace(), CUDAPlace(0)};
|
|
|
|
|
std::array<Place, 2> kPlace = {{CPUPlace(), CUDAPlace(0)}};
|
|
|
|
|
|
|
|
|
|
std::array<DataLayout, 2> kDataLayout = {
|
|
|
|
|
std::array<DataLayout, 2> kDataLayout = {{
|
|
|
|
|
DataLayout::kNHWC, DataLayout::kNCHW,
|
|
|
|
|
};
|
|
|
|
|
}};
|
|
|
|
|
|
|
|
|
|
std::array<LibraryType, 2> kLibraryType = {
|
|
|
|
|
std::array<LibraryType, 2> kLibraryType = {{
|
|
|
|
|
LibraryType::kPlain, LibraryType::kMKLDNN,
|
|
|
|
|
};
|
|
|
|
|
}};
|
|
|
|
|
|
|
|
|
|
OpKernelType GenFromBit(const std::vector<bool> bits) {
|
|
|
|
|
return OpKernelType(kDataType[bits[0]], kPlace[bits[1]], kDataLayout[bits[2]],
|
|
|
|
|