|
|
@ -133,12 +133,14 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
|
|
|
|
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
|
|
|
|
|
|
|
|
CUDNN_TENSOR_OP_MATH));
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
|
|
|
|
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
|
|
|
|
|
|
|
|
CUDNN_DEFAULT_MATH));
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
@ -148,10 +150,11 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
int perf_count;
|
|
|
|
int perf_count;
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
|
|
|
|
platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
|
|
|
|
args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, &perf_count,
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(),
|
|
|
|
perf_results.get()));
|
|
|
|
args.cdesc.desc(), args.odesc.desc(), kNUM_CUDNN_FWD_ALGS,
|
|
|
|
|
|
|
|
&perf_count, perf_results.get()));
|
|
|
|
algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
workspace_size = GetWorkspaceSize(args, algo);
|
|
|
|
workspace_size = GetWorkspaceSize(args, algo);
|
|
|
|
|
|
|
|
|
|
|
@ -163,17 +166,20 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(),
|
|
|
|
platform::dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
args.cdesc.desc(), args.odesc.desc(),
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(),
|
|
|
|
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit,
|
|
|
|
args.cdesc.desc(), args.odesc.desc(),
|
|
|
|
&algo));
|
|
|
|
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#else
|
|
|
|
#else
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
|
|
|
|
platform::dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
args.odesc.desc(), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(),
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
args.cdesc.desc(), args.odesc.desc(),
|
|
|
|
|
|
|
|
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
VLOG(3) << "choose algo " << algo;
|
|
|
|
VLOG(3) << "choose algo " << algo;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -197,7 +203,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
|
|
|
|
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
|
|
|
|
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
|
|
|
|
args.handle, args.idesc.desc(), args.x->data<T>(),
|
|
|
|
args.handle, args.idesc.desc(), args.x->data<T>(),
|
|
|
|
args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
|
|
|
|
args.wdesc.desc(), args.w->data<T>(), args.cdesc.desc(),
|
|
|
@ -223,9 +229,10 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
|
|
|
|
|
|
|
|
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
|
|
|
|
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
|
|
|
|
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
|
|
|
|
args.odesc.desc(), algo, &workspace_size));
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(),
|
|
|
|
|
|
|
|
args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size));
|
|
|
|
return workspace_size;
|
|
|
|
return workspace_size;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
@ -249,12 +256,14 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
|
|
|
|
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
|
|
|
|
|
|
|
|
CUDNN_TENSOR_OP_MATH));
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
|
|
|
|
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
|
|
|
|
|
|
|
|
CUDNN_DEFAULT_MATH));
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
@ -265,7 +274,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(
|
|
|
|
new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
|
|
|
|
new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS,
|
|
|
|
args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS,
|
|
|
@ -294,7 +303,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
args.cdesc.desc(), args.idesc.desc(),
|
|
|
|
args.cdesc.desc(), args.idesc.desc(),
|
|
|
@ -302,10 +311,12 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#else
|
|
|
|
#else
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(), args.cdesc.desc(),
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
args.idesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
args.cdesc.desc(), args.idesc.desc(),
|
|
|
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
} else if (deterministic) {
|
|
|
|
} else if (deterministic) {
|
|
|
|
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
|
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
@ -330,7 +341,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
|
|
|
|
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::
|
|
|
|
platform::dynload::
|
|
|
|
cudnnFindConvolutionBackwardDataAlgorithmEx(
|
|
|
|
cudnnFindConvolutionBackwardDataAlgorithmEx(
|
|
|
|
args.handle, args.wdesc.desc(), args.w->data<T>(),
|
|
|
|
args.handle, args.wdesc.desc(), args.w->data<T>(),
|
|
|
@ -359,7 +370,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
|
|
|
|
|
|
|
|
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
|
|
|
|
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size));
|
|
|
|
args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size));
|
|
|
@ -385,12 +396,14 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
|
|
|
|
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
|
|
|
|
|
|
|
|
CUDNN_TENSOR_OP_MATH));
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
|
|
|
|
platform::dynload::cudnnSetConvolutionMathType(args.cdesc.desc(),
|
|
|
|
|
|
|
|
CUDNN_DEFAULT_MATH));
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
@ -403,7 +416,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(
|
|
|
|
new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
|
|
|
|
new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
|
|
|
@ -418,7 +431,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(),
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(),
|
|
|
@ -426,7 +439,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#else
|
|
|
|
#else
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(),
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(),
|
|
|
@ -455,7 +468,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
int returned_algo_count;
|
|
|
|
int returned_algo_count;
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::
|
|
|
|
platform::dynload::
|
|
|
|
cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
|
|
|
cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
|
|
|
args.handle, args.idesc.desc(), args.x->data<T>(),
|
|
|
|
args.handle, args.idesc.desc(), args.x->data<T>(),
|
|
|
@ -483,7 +496,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
|
|
|
|
|
|
|
|
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
|
|
|
|
static size_t GetWorkspaceSize(const ConvArgs& args, algo_t algo) {
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(), algo, &workspace_size));
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(), algo, &workspace_size));
|
|
|
|