|
|
|
@ -216,6 +216,12 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
perf_results.get()));
|
|
|
|
|
algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
|
VLOG(3) << "cuDNN forward algo " << algo;
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
|
|
|
|
|
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
|
|
|
|
|
cudnn_output_desc, algo, &workspace_size_in_bytes));
|
|
|
|
|
if (workspace_size_in_bytes > workspace_size_limit)
|
|
|
|
|
workspace_size_limit = workspace_size_in_bytes;
|
|
|
|
|
} else {
|
|
|
|
|
std::function<cudnnConvolutionFwdAlgo_t()> search_func =
|
|
|
|
|
[&]() -> cudnnConvolutionFwdAlgo_t {
|
|
|
|
|