@@ -139,9 +139,8 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     // ------------------- cudnn conv algorithm ---------------------
     cudnnConvolutionFwdAlgo_t algo;
-    auto handle = dev_ctx.cudnn_handle();
     bool half_float = false;
 #if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
     // Tensor core is supported since the volta GPU and
     // is only enabled when input and filter data are float16
@@ -160,9 +159,9 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
       VLOG(5) << "NOT use cudnn_tensor_op_math";
     }
 #endif
-    Tensor cudnn_workspace;
-    void* cudnn_workspace_ptr = nullptr;
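+    // The device context owns a shared cuDNN workspace, so the kernel no
+    // longer keeps a workspace tensor of its own.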
+    auto handle = dev_ctx.cudnn_handle();
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
     auto x_dims = framework::vectorize(input->dims());
     auto f_dims = framework::vectorize(filter->dims());
     if ((!exhaustive_search) && (!half_float)) {
@@ -174,12 +173,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     } else if (exhaustive_search && (!half_float)) {
       AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
           ctx.GetKernelConfig<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(0);
-      cudnn_workspace =
-          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
-              framework::make_ddim(
-                  {static_cast<int64_t>(workspace_size_limit)}),
-              dev_ctx);
-      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
       algo = algo_cache.GetAlgorithm(
           x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
@@ -187,13 +180,16 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
             std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
                 fwd_perf_stat;
-            CUDNN_ENFORCE(
-                platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
-                    handle, cudnn_input_desc, input_data, cudnn_filter_desc,
-                    filter_data, cudnn_conv_desc, cudnn_output_desc,
-                    output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
-                    fwd_perf_stat.data(), cudnn_workspace_ptr,
-                    workspace_size_limit));
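+            // The exhaustive search now runs inside a workspace callback:
+            // RunFuncSync hands the functor scratch memory of up to
+            // workspace_size_limit bytes and synchronizes before returning.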
+            auto cudnn_find_func = [&](void* cudnn_workspace) {
+              CUDNN_ENFORCE(
+                  platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
+                      handle, cudnn_input_desc, input_data, cudnn_filter_desc,
+                      filter_data, cudnn_conv_desc, cudnn_output_desc,
+                      output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
+                      fwd_perf_stat.data(), cudnn_workspace,
+                      workspace_size_limit));
+            };
+            workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
             VLOG(3) << "Perf result: (algo: stat, time, memory)";
             for (int i = 0; i < returned_algo_count; ++i) {
@@ -219,14 +215,13 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
                       "workspace_size to be allocated exceeds the limit");
     // Allocate on GPU memory
-    if (!cudnn_workspace_ptr) {
-      cudnn_workspace =
-          ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
-              framework::make_ddim(
-                  {static_cast<int64_t>(workspace_size_in_bytes)}),
-              dev_ctx);
-      cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
-    }
+    Tensor cudnn_workspace =
+        ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
+            framework::make_ddim(
+                {static_cast<int64_t>(workspace_size_in_bytes)}),
+            dev_ctx);
+    void* cudnn_workspace_ptr =
+        static_cast<void*>(cudnn_workspace.data<int8_t>());
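+    // NOTE: cudnn_workspace is a temporary tensor scoped to this kernel; its
+    // memory is expected to be reclaimed once the tensor goes out of scope.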
     // ------------------- cudnn conv forward ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
     for (int i = 0; i < groups; i++) {