|
|
|
@ -76,12 +76,16 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
|
|
|
|
|
const std::string& name) {
|
|
|
|
|
framework::LibraryType library{framework::LibraryType::kPlain};
|
|
|
|
|
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
auto it1 = oper.Attrs().find("use_cudnn");
|
|
|
|
|
if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
|
|
|
|
|
library = framework::LibraryType::kCUDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
// FIXME(liuwei1031) temporarily disable the code to unblock users
|
|
|
|
|
// TODO(liuwei1031) figure out the reason behind
|
|
|
|
|
// https://github.com/PaddlePaddle/Paddle/issues/16096
|
|
|
|
|
// and re-enable this in the future
|
|
|
|
|
// #ifdef PADDLE_WITH_CUDA
|
|
|
|
|
// auto it1 = oper.Attrs().find("use_cudnn");
|
|
|
|
|
// if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
|
|
|
|
|
// library = framework::LibraryType::kCUDNN;
|
|
|
|
|
// }
|
|
|
|
|
// #endif
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
auto it = oper.Attrs().find("use_mkldnn");
|
|
|
|
|
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
|
|
|
|
|