@ -29,13 +29,14 @@ namespace platform {
using framework : : Tensor ;
template < typename T >
cudnnDataType_t ToCudnnDataType ( const T & t ) {
inline cudnnDataType_t ToCudnnDataType ( const T & t ) {
auto type = framework : : ToDataType ( t ) ;
return ToCudnnDataType ( type ) ;
}
template < >
cudnnDataType_t ToCudnnDataType ( const framework : : proto : : VarType : : Type & t ) {
inline cudnnDataType_t ToCudnnDataType (
const framework : : proto : : VarType : : Type & t ) {
cudnnDataType_t type = CUDNN_DATA_FLOAT ;
switch ( t ) {
case framework : : proto : : VarType : : FP16 :
@ -59,14 +60,14 @@ class ActivationDescriptor {
struct Deleter {
void operator ( ) ( T * t ) {
if ( t ! = nullptr ) {
PADDLE _ENFORCE( dynload : : cudnnDestroyActivationDescriptor ( t ) ) ;
CUDNN _ENFORCE( dynload : : cudnnDestroyActivationDescriptor ( t ) ) ;
t = nullptr ;
}
}
} ;
ActivationDescriptor ( ) {
T * raw_ptr ;
PADDLE _ENFORCE( dynload : : cudnnCreateActivationDescriptor ( & raw_ptr ) ) ;
CUDNN _ENFORCE( dynload : : cudnnCreateActivationDescriptor ( & raw_ptr ) ) ;
desc_ . reset ( raw_ptr ) ;
}
template < typename T >
@ -88,14 +89,14 @@ class TensorDescriptor {
struct Deleter {
void operator ( ) ( T * t ) {
if ( t ! = nullptr ) {
PADDLE _ENFORCE( dynload : : cudnnDestroyTensorDescriptor ( t ) ) ;
CUDNN _ENFORCE( dynload : : cudnnDestroyTensorDescriptor ( t ) ) ;
t = nullptr ;
}
}
} ;
TensorDescriptor ( ) {
T * raw_ptr ;
PADDLE _ENFORCE( dynload : : cudnnCreateTensorDescriptor ( & raw_ptr ) ) ;
CUDNN _ENFORCE( dynload : : cudnnCreateTensorDescriptor ( & raw_ptr ) ) ;
desc_ . reset ( raw_ptr ) ;
}
T * desc ( ) { return desc_ . get ( ) ; }
@ -111,7 +112,7 @@ class TensorDescriptor {
if ( groups > 1 ) {
dims_with_group [ 1 ] = dims_with_group [ 1 ] / groups ;
}
PADDLE _ENFORCE( dynload : : cudnnSetTensorNdDescriptor (
CUDNN _ENFORCE( dynload : : cudnnSetTensorNdDescriptor (
desc_ . get ( ) , ToCudnnDataType ( tensor . type ( ) ) , dims_with_group . size ( ) ,
dims_with_group . data ( ) , strides . data ( ) ) ) ;
}
@ -120,5 +121,83 @@ class TensorDescriptor {
std : : unique_ptr < T , Deleter > desc_ ;
} ;
class FilterDescriptor {
public :
using T = cudnnFilterStruct ;
struct Deleter {
void operator ( ) ( T * t ) {
if ( t ! = nullptr ) {
CUDNN_ENFORCE ( dynload : : cudnnDestroyFilterDescriptor ( t ) ) ;
t = nullptr ;
}
}
} ;
FilterDescriptor ( ) {
T * raw_ptr ;
CUDNN_ENFORCE ( dynload : : cudnnCreateFilterDescriptor ( & raw_ptr ) ) ;
desc_ . reset ( raw_ptr ) ;
}
T * desc ( ) { return desc_ . get ( ) ; }
T * desc ( ) const { return desc_ . get ( ) ; }
void set ( const Tensor & tensor , const cudnnTensorFormat_t format ,
const int groups = 1 ) {
auto dims = framework : : vectorize2int ( tensor . dims ( ) ) ;
if ( groups > 1 ) {
dims [ 1 ] = dims [ 1 ] / groups ;
}
CUDNN_ENFORCE ( dynload : : cudnnSetFilterNdDescriptor (
desc_ . get ( ) , ToCudnnDataType ( tensor . type ( ) ) , format , dims . size ( ) ,
dims . data ( ) ) ) ;
}
private :
std : : unique_ptr < T , Deleter > desc_ ;
} ;
class ConvolutionDescriptor {
public :
using T = cudnnConvolutionStruct ;
struct Deleter {
void operator ( ) ( T * t ) {
if ( t ! = nullptr ) {
CUDNN_ENFORCE ( dynload : : cudnnDestroyConvolutionDescriptor ( t ) ) ;
t = nullptr ;
}
}
} ;
ConvolutionDescriptor ( ) {
T * raw_ptr ;
CUDNN_ENFORCE ( dynload : : cudnnCreateConvolutionDescriptor ( & raw_ptr ) ) ;
desc_ . reset ( raw_ptr ) ;
}
T * desc ( ) { return desc_ . get ( ) ; }
T * desc ( ) const { return desc_ . get ( ) ; }
void set ( cudnnDataType_t dtype , const std : : vector < int > & pads ,
const std : : vector < int > & strides , const std : : vector < int > & dilations ,
const int groups = 1 ) {
cudnnDataType_t compute_type =
( dtype = = CUDNN_DATA_DOUBLE ) ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT ;
T * desc = desc_ . get ( ) ;
CUDNN_ENFORCE ( dynload : : cudnnSetConvolutionNdDescriptor (
desc , pads . size ( ) , pads . data ( ) , strides . data ( ) , dilations . data ( ) ,
CUDNN_CROSS_CORRELATION , compute_type ) ) ;
CUDNN_ENFORCE ( platform : : dynload : : cudnnSetConvolutionMathType (
desc , CUDNN_DEFAULT_MATH ) ) ;
# if CUDNN_VERSION_MIN(7, 0, 1)
CUDNN_ENFORCE (
platform : : dynload : : cudnnSetConvolutionGroupCount ( desc , groups ) ) ;
if ( dtype = = CUDNN_DATA_HALF ) {
CUDNN_ENFORCE ( platform : : dynload : : cudnnSetConvolutionMathType (
desc , CUDNN_TENSOR_OP_MATH ) ) ;
}
# endif
}
private :
std : : unique_ptr < T , Deleter > desc_ ;
} ;
} // namespace platform
} // namespace paddle