|
|
|
@ -166,10 +166,23 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// TODO(dangqingqing) simplify the following code by SearchAlgorithm in
|
|
|
|
|
// conv_cudnn_helper.h
|
|
|
|
|
if ((!exhaustive_search) && (!half_float)) {
|
|
|
|
|
#if CUDNN_VERSION >= 7001
|
|
|
|
|
using perf_t = cudnnConvolutionFwdAlgoPerf_t;
|
|
|
|
|
int perf_count;
|
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
|
|
|
|
|
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
|
|
|
|
|
cudnn_output_desc, kNUM_CUDNN_FWD_ALGS, &perf_count,
|
|
|
|
|
perf_results.get()));
|
|
|
|
|
algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
|
#else
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
|
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
|
|
|
|
|
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "cuDNN forward algo " << algo;
|
|
|
|
|
} else if (exhaustive_search && (!half_float)) {
|
|
|
|
|
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
|
|
|
|
@ -388,6 +401,37 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
} else if (FLAGS_cudnn_deterministic) {
|
|
|
|
|
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
|
|
} else {
|
|
|
|
|
#if CUDNN_VERSION >= 7001
|
|
|
|
|
using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
|
|
|
|
|
int perf_count;
|
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(
|
|
|
|
|
new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
|
|
|
|
|
handle, cudnn_filter_desc,
|
|
|
|
|
// dyDesc: Handle to the previously initialized input
|
|
|
|
|
// differential
|
|
|
|
|
// tensor descriptor.
|
|
|
|
|
cudnn_output_grad_desc, cudnn_conv_desc,
|
|
|
|
|
// dxDesc: Handle to the previously initialized output tensor
|
|
|
|
|
// descriptor.
|
|
|
|
|
cudnn_input_desc, kNUM_CUDNN_BWD_DATA_ALGS, &perf_count,
|
|
|
|
|
perf_results.get()));
|
|
|
|
|
data_algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
|
int stride_dim = input->dims().size() - 2;
|
|
|
|
|
bool blacklist =
|
|
|
|
|
std::any_of(strides.begin(), strides.begin() + stride_dim,
|
|
|
|
|
[=](int n) { return n != 1; });
|
|
|
|
|
if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
|
|
|
|
|
perf_results[best_algo_idx].algo) ==
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
|
|
|
|
|
static_cast<cudnnConvolutionBwdDataAlgo_t>(
|
|
|
|
|
perf_results[best_algo_idx].algo) ==
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
|
|
|
|
|
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
|
handle, cudnn_filter_desc,
|
|
|
|
@ -400,6 +444,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
cudnn_input_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &data_algo));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
|
|
@ -437,12 +482,27 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
} else if (FLAGS_cudnn_deterministic) {
|
|
|
|
|
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
|
|
|
|
} else {
|
|
|
|
|
#if CUDNN_VERSION >= 7001
|
|
|
|
|
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
|
|
|
|
|
int perf_count;
|
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(
|
|
|
|
|
new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
|
|
|
|
|
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
|
|
|
|
|
handle, cudnn_input_desc, cudnn_output_grad_desc,
|
|
|
|
|
cudnn_conv_desc, cudnn_filter_desc, kNUM_CUDNN_BWD_FILTER_ALGS,
|
|
|
|
|
&perf_count, perf_results.get()));
|
|
|
|
|
filter_algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
|
#else
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
|
handle, cudnn_input_desc, cudnn_output_grad_desc,
|
|
|
|
|
cudnn_conv_desc, cudnn_filter_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &filter_algo));
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
|
|
|