|
|
|
@ -21,6 +21,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/operator_kernel_configs.h"
|
|
|
|
|
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
|
|
|
|
|
#include "paddle/fluid/platform/cudnn_desc.h"
|
|
|
|
|
// #include "paddle/fluid/platform/device_context.h"
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -89,7 +90,43 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
|
|
|
|
|
return out;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using framework::AlgorithmsCache;
|
|
|
|
|
// ConvSearchCache using framework::AlgorithmsCache to search
|
|
|
|
|
// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or
|
|
|
|
|
// cudnnConvolutionBwdFilterAlgo_t
|
|
|
|
|
class ConvSearchCache {
|
|
|
|
|
public:
|
|
|
|
|
static ConvSearchCache& Instance() {
|
|
|
|
|
static ConvSearchCache instance;
|
|
|
|
|
return instance;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
|
|
|
|
|
return &forward_cache_;
|
|
|
|
|
}
|
|
|
|
|
framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
|
|
|
|
|
return &backward_data_cache_;
|
|
|
|
|
}
|
|
|
|
|
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>*
|
|
|
|
|
GetBackwardFilter() {
|
|
|
|
|
return &backward_filter_cache_;
|
|
|
|
|
}
|
|
|
|
|
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
|
|
|
|
|
return &fusion_forward_cache_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
ConvSearchCache() {}
|
|
|
|
|
~ConvSearchCache() {}
|
|
|
|
|
ConvSearchCache(const ConvSearchCache&) {}
|
|
|
|
|
ConvSearchCache& operator=(const ConvSearchCache&) {}
|
|
|
|
|
|
|
|
|
|
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
|
|
|
|
|
framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>
|
|
|
|
|
backward_data_cache_;
|
|
|
|
|
framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>
|
|
|
|
|
backward_filter_cache_;
|
|
|
|
|
framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct ConvArgs {
|
|
|
|
|
cudnnHandle_t handle;
|
|
|
|
@ -97,6 +134,7 @@ struct ConvArgs {
|
|
|
|
|
platform::FilterDescriptor wdesc;
|
|
|
|
|
platform::ConvolutionDescriptor cdesc;
|
|
|
|
|
const framework::Tensor *x, *w, *o;
|
|
|
|
|
cudnnDataType_t cudnn_dtype;
|
|
|
|
|
|
|
|
|
|
// strides
|
|
|
|
|
std::vector<int> s;
|
|
|
|
@ -107,8 +145,9 @@ struct ConvArgs {
|
|
|
|
|
|
|
|
|
|
ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
|
|
|
|
|
const framework::Tensor* o, const std::vector<int> s,
|
|
|
|
|
const std::vector<int> p, const std::vector<int> d)
|
|
|
|
|
: x(x), w(w), o(o), s(s), p(p), d(d) {}
|
|
|
|
|
const std::vector<int> p, const std::vector<int> d,
|
|
|
|
|
cudnnDataType_t dtype)
|
|
|
|
|
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename perf_t>
|
|
|
|
@ -121,7 +160,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
|
|
|
|
|
bool deterministic, int algo_cache_id,
|
|
|
|
|
bool deterministic,
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
bool has_got_workspace_size = true;
|
|
|
|
@ -183,22 +222,24 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
|
|
|
|
|
#endif
|
|
|
|
|
VLOG(3) << "choose algo " << algo;
|
|
|
|
|
} else {
|
|
|
|
|
AlgorithmsCache<algo_t>& algo_cache =
|
|
|
|
|
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
|
|
|
|
|
|
|
|
|
|
auto& temp = ctx.cuda_device_context();
|
|
|
|
|
AlgorithmsCache<algo_t>& algo_cache =
|
|
|
|
|
*(ConvSearchCache::Instance().GetForward());
|
|
|
|
|
|
|
|
|
|
auto x_dims = framework::vectorize(args.x->dims());
|
|
|
|
|
auto w_dims = framework::vectorize(args.w->dims());
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
|
|
|
|
|
<< algo_cache_id << ", x_dims:" << x_dims
|
|
|
|
|
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
|
|
|
|
|
<< args.p << ", args.d" << args.d;
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
|
|
|
|
|
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
|
|
|
|
|
<< args.s << ", args.p" << args.p << ", args.d" << args.d;
|
|
|
|
|
|
|
|
|
|
algo = algo_cache.GetAlgorithm(
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0,
|
|
|
|
|
static_cast<int64_t>(args.cudnn_dtype), [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
|
|
|
|
|
@ -244,7 +285,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
|
|
|
|
|
bool deterministic, int algo_cache_id,
|
|
|
|
|
bool deterministic,
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
|
|
|
|
@ -321,22 +362,23 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
|
|
|
|
|
} else if (deterministic) {
|
|
|
|
|
return CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
|
|
} else {
|
|
|
|
|
AlgorithmsCache<algo_t>& algo_cache =
|
|
|
|
|
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
|
|
|
|
|
|
|
|
|
|
AlgorithmsCache<algo_t>& algo_cache =
|
|
|
|
|
*(ConvSearchCache::Instance().GetBackwardData());
|
|
|
|
|
|
|
|
|
|
auto x_dims = framework::vectorize(args.x->dims());
|
|
|
|
|
auto w_dims = framework::vectorize(args.w->dims());
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
|
|
|
|
|
<< algo_cache_id << ", x_dims:" << x_dims
|
|
|
|
|
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
|
|
|
|
|
<< args.p << ", args.d" << args.d;
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t"
|
|
|
|
|
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
|
|
|
|
|
<< args.s << ", args.p" << args.p << ", args.d" << args.d;
|
|
|
|
|
|
|
|
|
|
algo = algo_cache.GetAlgorithm(
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0,
|
|
|
|
|
static_cast<int64_t>(args.cudnn_dtype), [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
|
|
|
|
|
@ -385,7 +427,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
static algo_t Find(const ConvArgs& args, bool exhaustive_search,
|
|
|
|
|
bool deterministic, int algo_cache_id,
|
|
|
|
|
bool deterministic,
|
|
|
|
|
const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
bool exhaustive = (exhaustive_search) & (dtype != CUDNN_DATA_HALF);
|
|
|
|
@ -449,22 +491,22 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
|
|
|
|
|
} else if (deterministic) {
|
|
|
|
|
return CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
|
|
|
|
} else {
|
|
|
|
|
AlgorithmsCache<algo_t>& algo_cache =
|
|
|
|
|
ctx.GetKernelConfig<AlgorithmsCache<algo_t>>(algo_cache_id);
|
|
|
|
|
auto& dev_ctx =
|
|
|
|
|
ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
|
|
|
|
|
AlgorithmsCache<algo_t>& algo_cache =
|
|
|
|
|
*(ConvSearchCache::Instance().GetBackwardFilter());
|
|
|
|
|
|
|
|
|
|
auto x_dims = framework::vectorize(args.x->dims());
|
|
|
|
|
auto w_dims = framework::vectorize(args.w->dims());
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t algo_cache_id:"
|
|
|
|
|
<< algo_cache_id << ", x_dims:" << x_dims
|
|
|
|
|
<< ", w_dims:" << w_dims << ", args.s" << args.s << ", args.p"
|
|
|
|
|
<< args.p << ", args.d" << args.d;
|
|
|
|
|
VLOG(10) << "cudnnConvolutionFwdAlgoPerf_t:"
|
|
|
|
|
<< ", x_dims:" << x_dims << ", w_dims:" << w_dims << ", args.s"
|
|
|
|
|
<< args.s << ", args.p" << args.p << ", args.d" << args.d;
|
|
|
|
|
|
|
|
|
|
algo = algo_cache.GetAlgorithm(
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0, [&]() {
|
|
|
|
|
x_dims, w_dims, args.s, args.p, args.d, 0,
|
|
|
|
|
static_cast<int64_t>(args.cudnn_dtype), [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|
std::array<perf_t, kNUM_CUDNN_FWD_ALGS> perf_stat;
|
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
|
|
|
|
|