fix cudnn workspace size problem during inference. (#26021)

test=develop
Zhaolong Xing committed via GitHub
parent 1f74b94d3f
commit 50f149a48e

@@ -216,6 +216,12 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
               perf_results.get()));
       algo = (perf_results.get())[best_algo_idx].algo;
       VLOG(3) << "cuDNN forward algo " << algo;
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
+              handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+              cudnn_output_desc, algo, &workspace_size_in_bytes));
+      if (workspace_size_in_bytes > workspace_size_limit)
+        workspace_size_limit = workspace_size_in_bytes;
     } else {
       std::function<cudnnConvolutionFwdAlgo_t()> search_func =
          [&]() -> cudnnConvolutionFwdAlgo_t {
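For context, below is a minimal standalone sketch (not PaddlePaddle code; the descriptor shapes and the initial 4 MB limit are made-up example values) of the pattern the patch applies: after cuDNN's heuristics pick a forward algorithm, query cudnnGetConvolutionForwardWorkspaceSize for that specific algorithm and raise the workspace limit if the reported requirement exceeds it, so the workspace allocated later is never too small for the chosen algo.

// conv_workspace_sketch.cc — illustrative only; compile with -lcudnn
#include <cudnn.h>
#include <cstdio>
#include <cstdlib>

#define CHECK_CUDNN(call)                                                 \
  do {                                                                    \
    cudnnStatus_t s = (call);                                             \
    if (s != CUDNN_STATUS_SUCCESS) {                                      \
      fprintf(stderr, "cuDNN error %s at line %d\n",                      \
              cudnnGetErrorString(s), __LINE__);                          \
      exit(1);                                                            \
    }                                                                     \
  } while (0)

int main() {
  cudnnHandle_t handle;
  CHECK_CUDNN(cudnnCreate(&handle));

  // Example NCHW conv: input 1x3x224x224, filter 64x3x7x7, pad 3, stride 2.
  cudnnTensorDescriptor_t in_desc, out_desc;
  cudnnFilterDescriptor_t filter_desc;
  cudnnConvolutionDescriptor_t conv_desc;
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&in_desc));
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&out_desc));
  CHECK_CUDNN(cudnnCreateFilterDescriptor(&filter_desc));
  CHECK_CUDNN(cudnnCreateConvolutionDescriptor(&conv_desc));
  CHECK_CUDNN(cudnnSetTensor4dDescriptor(in_desc, CUDNN_TENSOR_NCHW,
                                         CUDNN_DATA_FLOAT, 1, 3, 224, 224));
  CHECK_CUDNN(cudnnSetFilter4dDescriptor(filter_desc, CUDNN_DATA_FLOAT,
                                         CUDNN_TENSOR_NCHW, 64, 3, 7, 7));
  CHECK_CUDNN(cudnnSetConvolution2dDescriptor(conv_desc, 3, 3, 2, 2, 1, 1,
                                              CUDNN_CROSS_CORRELATION,
                                              CUDNN_DATA_FLOAT));
  int n, c, h, w;
  CHECK_CUDNN(cudnnGetConvolution2dForwardOutputDim(conv_desc, in_desc,
                                                    filter_desc,
                                                    &n, &c, &h, &w));
  CHECK_CUDNN(cudnnSetTensor4dDescriptor(out_desc, CUDNN_TENSOR_NCHW,
                                         CUDNN_DATA_FLOAT, n, c, h, w));

  // Let cuDNN's heuristics pick a forward algorithm.
  int returned = 0;
  cudnnConvolutionFwdAlgoPerf_t perf;
  CHECK_CUDNN(cudnnGetConvolutionForwardAlgorithm_v7(
      handle, in_desc, filter_desc, conv_desc, out_desc,
      /*requestedAlgoCount=*/1, &returned, &perf));
  cudnnConvolutionFwdAlgo_t algo = perf.algo;

  // The core of the fix: ask how much workspace the chosen algorithm needs
  // and grow the limit if the requirement is larger than the current cap.
  size_t workspace_size_limit = 4UL << 20;  // example initial cap: 4 MB
  size_t workspace_size_in_bytes = 0;
  CHECK_CUDNN(cudnnGetConvolutionForwardWorkspaceSize(
      handle, in_desc, filter_desc, conv_desc, out_desc, algo,
      &workspace_size_in_bytes));
  if (workspace_size_in_bytes > workspace_size_limit)
    workspace_size_limit = workspace_size_in_bytes;
  printf("algo %d needs %zu bytes, workspace limit is now %zu bytes\n",
         (int)algo, workspace_size_in_bytes, workspace_size_limit);

  cudnnDestroyConvolutionDescriptor(conv_desc);
  cudnnDestroyFilterDescriptor(filter_desc);
  cudnnDestroyTensorDescriptor(out_desc);
  cudnnDestroyTensorDescriptor(in_desc);
  cudnnDestroy(handle);
  return 0;
}

The added lines in CUDNNConvFusionOpKernel perform the same query-then-grow step, only through Paddle's platform::dynload wrappers and PADDLE_ENFORCE_CUDA_SUCCESS instead of the raw cuDNN calls shown above.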
