|
|
|
@ -429,7 +429,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0,
|
|
|
|
|
static_cast<int64_t>(args.cudnn_dtype), [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
|
std::array<perf_t, kNUM_CUDNN_BWD_DATA_ALGS> perf_stat;
|
|
|
|
|
|
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
@ -561,7 +561,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0,
|
|
|
|
|
static_cast<int64_t>(args.cudnn_dtype), [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
|
std::array<perf_t, kNUM_CUDNN_BWD_FILTER_ALGS> perf_stat;
|
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::
|
|
|
|
|