|
|
|
|
@ -162,7 +162,20 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
|
workspace_size = GetWorkspaceSize(args, algo);
|
|
|
|
|
|
|
|
|
|
if (workspace_size > workspace_size_limit) {
|
|
|
|
|
#if CUDNN_VERSION >= 8000
|
|
|
|
|
workspace_size_limit = workspace_size;
|
|
|
|
|
#else
|
|
|
|
|
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
|
|
|
|
|
"the workspace size request("
|
|
|
|
|
<< workspace_size << ") exceeds the limit("
|
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.odesc.desc(),
|
|
|
|
|
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
@ -291,8 +304,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
|
#endif
|
|
|
|
|
workspace_size = GetWorkspaceSize(args, algo);
|
|
|
|
|
if (workspace_size > workspace_size_limit) {
|
|
|
|
|
workspace_size_limit = workspace_size;
|
|
|
|
|
has_got_workspace_size = false;
|
|
|
|
|
#if CUDNN_VERSION >= 8000
|
|
|
|
|
// There is no cudnnGetConvolutionBackwardDataAlgorithm in CUDNN 8
|
|
|
|
|
// version.
|
|
|
|
|
workspace_size_limit = workspace_size;
|
|
|
|
|
#else
|
|
|
|
|
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
|
|
|
|
|
"the workspace size request("
|
|
|
|
|
<< workspace_size << ") exceeds the limit("
|
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.idesc.desc(),
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
|