|
|
|
@ -14,11 +14,11 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/operator_kernel_configs.h"
|
|
|
|
|
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
|
|
|
|
|
#include "paddle/fluid/platform/cudnn_desc.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -57,16 +57,57 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
|
bool deterministic, int algo_cache_id,
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
bool has_got_workspace_size = true;
|
|
|
|
|
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
|
|
|
|
|
|
|
|
|
|
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
|
|
|
|
|
|
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
|
algo_t algo;
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
|
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
|
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
|
} else {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
|
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
|
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
if (!exhaustive) {
|
|
|
|
|
#if CUDNN_VERSION >= 7001
|
|
|
|
|
int perf_count;
|
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(new perf_t[kNUM_CUDNN_FWD_ALGS]);
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
|
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
|
|
|
|
|
args.odesc.desc(), kNUM_CUDNN_FWD_ALGS, &perf_count,
|
|
|
|
|
perf_results.get()));
|
|
|
|
|
algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
|
workspace_size = GetWorkspaceSize(args, algo);
|
|
|
|
|
|
|
|
|
|
if (workspace_size > workspace_size_limit) {
|
|
|
|
|
has_got_workspace_size = false;
|
|
|
|
|
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
|
|
|
|
|
"the workspace size request("
|
|
|
|
|
<< workspace_size << ") exceeds the limit("
|
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
|
}
|
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.odesc.desc(),
|
|
|
|
|
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, workspace_size_limit,
|
|
|
|
|
&algo));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
|
|
|
|
|
args.handle, args.idesc.desc(), args.wdesc.desc(), args.cdesc.desc(),
|
|
|
|
|
args.odesc.desc(), CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
|
#endif
|
|
|
|
|
VLOG(3) << "choose algo " << algo;
|
|
|
|
|
} else {
|
|
|
|
|
AlgorithmsCache<algo_t>& algo_cache =
|
|
|
|
@ -128,15 +169,72 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
|
|
|
|
|
|
|
|
|
|
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
|
|
|
|
|
|
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
|
bool has_got_workspace_size = true;
|
|
|
|
|
algo_t algo;
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
|
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
|
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
|
} else {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
|
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
|
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
if (!exhaustive && !deterministic) {
|
|
|
|
|
#if CUDNN_VERSION >= 7001
|
|
|
|
|
int perf_count;
|
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(
|
|
|
|
|
new perf_t[kNUM_CUDNN_BWD_DATA_ALGS]);
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm_v7(
|
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.idesc.desc(), kNUM_CUDNN_BWD_DATA_ALGS,
|
|
|
|
|
&perf_count, perf_results.get()));
|
|
|
|
|
algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
|
|
|
|
|
|
#if CUDNN_VERSION < 7500
|
|
|
|
|
int stride_dim = args.x->dims().size() - 2;
|
|
|
|
|
bool blacklist = std::any_of(args.s.begin(), args.s.begin() + stride_dim,
|
|
|
|
|
[=](int n) { return n != 1; });
|
|
|
|
|
if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(
|
|
|
|
|
perf_results[best_algo_idx].algo) ==
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING ||
|
|
|
|
|
static_cast<cudnnConvolutionBwdDataAlgo_t>(
|
|
|
|
|
perf_results[best_algo_idx].algo) ==
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
|
|
|
|
|
algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
workspace_size = GetWorkspaceSize(args, algo);
|
|
|
|
|
if (workspace_size > workspace_size_limit) {
|
|
|
|
|
has_got_workspace_size = false;
|
|
|
|
|
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
|
|
|
|
|
"the workspace size request("
|
|
|
|
|
<< workspace_size << ") exceeds the limit("
|
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
|
}
|
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.idesc.desc(),
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
|
args.handle, args.wdesc.desc(), args.idesc.desc(), args.cdesc.desc(),
|
|
|
|
|
args.odesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(), args.cdesc.desc(),
|
|
|
|
|
args.idesc.desc(), CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
|
#endif
|
|
|
|
|
} else if (deterministic) {
|
|
|
|
|
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
|
|
} else {
|
|
|
|
@ -186,8 +284,8 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
|
|
|
args.handle, args.wdesc.desc(), args.idesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.odesc.desc(), algo, &workspace_size));
|
|
|
|
|
args.handle, args.wdesc.desc(), args.odesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.idesc.desc(), algo, &workspace_size));
|
|
|
|
|
return workspace_size;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -203,17 +301,61 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
|
|
|
|
|
|
|
|
|
|
size_t workspace_size_limit = FLAGS_conv_workspace_size_limit * 1024 * 1024;
|
|
|
|
|
size_t workspace_size = 0;
|
|
|
|
|
bool has_got_workspace_size = true;
|
|
|
|
|
|
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
if (dev_ctx.GetComputeCapability() >= 70 && dtype == CUDNN_DATA_HALF) {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
|
args.cdesc.desc(), CUDNN_TENSOR_OP_MATH));
|
|
|
|
|
VLOG(5) << "use cudnn_tensor_op_math";
|
|
|
|
|
} else {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
|
|
|
|
|
args.cdesc.desc(), CUDNN_DEFAULT_MATH));
|
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
algo_t algo;
|
|
|
|
|
if (!exhaustive && !deterministic) {
|
|
|
|
|
#if CUDNN_VERSION >= 7001
|
|
|
|
|
using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
|
|
|
|
|
int perf_count;
|
|
|
|
|
int best_algo_idx = 0;
|
|
|
|
|
std::unique_ptr<perf_t[]> perf_results(
|
|
|
|
|
new perf_t[kNUM_CUDNN_BWD_FILTER_ALGS]);
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm_v7(
|
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(), kNUM_CUDNN_BWD_FILTER_ALGS,
|
|
|
|
|
&perf_count, perf_results.get()));
|
|
|
|
|
algo = (perf_results.get())[best_algo_idx].algo;
|
|
|
|
|
workspace_size = GetWorkspaceSize(args, algo);
|
|
|
|
|
if (workspace_size > workspace_size_limit) {
|
|
|
|
|
has_got_workspace_size = false;
|
|
|
|
|
VLOG(1) << "Fallback to non-v7 method to find conv algorithm becasue "
|
|
|
|
|
"the workspace size request("
|
|
|
|
|
<< workspace_size << ") exceeds the limit("
|
|
|
|
|
<< workspace_size_limit << ")";
|
|
|
|
|
}
|
|
|
|
|
if (!has_got_workspace_size) {
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(),
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
|
args.handle, args.idesc.desc(), args.odesc.desc(),
|
|
|
|
|
args.cdesc.desc(), args.wdesc.desc(),
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &algo));
|
|
|
|
|
#endif
|
|
|
|
|
} else if (deterministic) {
|
|
|
|
|
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
|
|
|
|
} else {
|
|
|
|
|