|
|
|
@ -70,16 +70,16 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
framework::LibraryType library{framework::LibraryType::kPlain};
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::CanCUDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|
library = framework::LibraryType::kCUDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
if (library == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
library = framework::LibraryType::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
@ -91,15 +91,15 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
"input and filter data type should be consistent");
|
|
|
|
|
|
|
|
|
|
if (input_data_type == framework::proto::VarType::FP16) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
|
|
|
|
|
PADDLE_ENFORCE_EQ(library, framework::LibraryType::kCUDNN,
|
|
|
|
|
"float16 can only be used when CUDNN is used");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
|
|
|
|
|
library_);
|
|
|
|
|
framework::DataLayout layout = framework::StringToDataLayout(data_format);
|
|
|
|
|
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
|
|
|
|
|
library);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|