@ -75,6 +75,11 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
framework : : OpKernelType ConvOp : : GetExpectedKernelType (
framework : : OpKernelType ConvOp : : GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const {
const framework : : ExecutionContext & ctx ) const {
framework : : LibraryType library { framework : : LibraryType : : kPlain } ;
framework : : LibraryType library { framework : : LibraryType : : kPlain } ;
std : : string data_format = ctx . Attr < std : : string > ( " data_format " ) ;
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework : : DataLayout layout = framework : : StringToDataLayout ( data_format ) ;
# ifdef PADDLE_WITH_CUDA
# ifdef PADDLE_WITH_CUDA
if ( platform : : CanCUDNNBeUsed ( ctx ) ) {
if ( platform : : CanCUDNNBeUsed ( ctx ) ) {
library = framework : : LibraryType : : kCUDNN ;
library = framework : : LibraryType : : kCUDNN ;
@ -84,6 +89,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
if ( library = = framework : : LibraryType : : kPlain & &
if ( library = = framework : : LibraryType : : kPlain & &
platform : : CanMKLDNNBeUsed ( ctx ) ) {
platform : : CanMKLDNNBeUsed ( ctx ) ) {
library = framework : : LibraryType : : kMKLDNN ;
library = framework : : LibraryType : : kMKLDNN ;
layout = framework : : DataLayout : : kMKLDNN ;
}
}
# endif
# endif
@ -99,9 +105,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
" float16 can only be used when CUDNN is used " ) ;
" float16 can only be used when CUDNN is used " ) ;
}
}
std : : string data_format = ctx . Attr < std : : string > ( " data_format " ) ;
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework : : DataLayout layout = framework : : StringToDataLayout ( data_format ) ;
return framework : : OpKernelType ( input_data_type , ctx . GetPlace ( ) , layout ,
return framework : : OpKernelType ( input_data_type , ctx . GetPlace ( ) , layout ,
library ) ;
library ) ;
}
}
@ -309,6 +312,10 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
framework : : OpKernelType ConvOpGrad : : GetExpectedKernelType (
framework : : OpKernelType ConvOpGrad : : GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const {
const framework : : ExecutionContext & ctx ) const {
framework : : LibraryType library_ { framework : : LibraryType : : kPlain } ;
framework : : LibraryType library_ { framework : : LibraryType : : kPlain } ;
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
std : : string data_format = ctx . Attr < std : : string > ( " data_format " ) ;
framework : : DataLayout layout_ = framework : : StringToDataLayout ( data_format ) ;
# ifdef PADDLE_WITH_CUDA
# ifdef PADDLE_WITH_CUDA
if ( platform : : CanCUDNNBeUsed ( ctx ) ) {
if ( platform : : CanCUDNNBeUsed ( ctx ) ) {
library_ = framework : : LibraryType : : kCUDNN ;
library_ = framework : : LibraryType : : kCUDNN ;
@ -318,12 +325,10 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
if ( library_ = = framework : : LibraryType : : kPlain & &
if ( library_ = = framework : : LibraryType : : kPlain & &
platform : : CanMKLDNNBeUsed ( ctx ) ) {
platform : : CanMKLDNNBeUsed ( ctx ) ) {
library_ = framework : : LibraryType : : kMKLDNN ;
library_ = framework : : LibraryType : : kMKLDNN ;
layout_ = framework : : DataLayout : : kMKLDNN ;
}
}
# endif
# endif
std : : string data_format = ctx . Attr < std : : string > ( " data_format " ) ;
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework : : DataLayout layout_ = framework : : StringToDataLayout ( data_format ) ;
return framework : : OpKernelType (
return framework : : OpKernelType (
framework : : ToDataType ( ctx . Input < Tensor > ( " Input " ) - > type ( ) ) , ctx . GetPlace ( ) ,
framework : : ToDataType ( ctx . Input < Tensor > ( " Input " ) - > type ( ) ) , ctx . GetPlace ( ) ,
layout_ , library_ ) ;
layout_ , library_ ) ;