|
|
|
@ -137,6 +137,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// ------------------- cudnn conv algorithm ---------------------
|
|
|
|
|
cudnnConvolutionFwdAlgo_t algo;
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
|
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
|
|
|
|
|
|
|
|
|
|
bool half_float = false;
|
|
|
|
|
#if CUDA_VERSION >= 9000 && CUDNN_VERSION_MIN(7, 0, 1)
|
|
|
|
@ -157,8 +158,6 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
VLOG(5) << "NOT use cudnn_tensor_op_math";
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
Tensor cudnn_workspace;
|
|
|
|
|
void* cudnn_workspace_ptr = nullptr;
|
|
|
|
|
|
|
|
|
|
auto x_dims = framework::vectorize(input->dims());
|
|
|
|
|
auto f_dims = framework::vectorize(filter->dims());
|
|
|
|
@ -181,26 +180,21 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
.Var(kCUDNNFwdAlgoCache)
|
|
|
|
|
->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
|
|
|
|
|
}
|
|
|
|
|
cudnn_workspace =
|
|
|
|
|
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(workspace_size_limit)}),
|
|
|
|
|
dev_ctx);
|
|
|
|
|
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
|
|
|
|
|
|
|
|
|
|
algo = algo_cache->GetAlgorithm(
|
|
|
|
|
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
|
|
|
|
|
fwd_perf_stat;
|
|
|
|
|
|
|
|
|
|
auto cudnn_find_func = [&](void* cudnn_workspace) {
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
|
|
|
|
|
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
|
|
|
|
|
filter_data, cudnn_conv_desc, cudnn_output_desc,
|
|
|
|
|
output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
|
|
|
|
|
fwd_perf_stat.data(), cudnn_workspace_ptr,
|
|
|
|
|
fwd_perf_stat.data(), cudnn_workspace,
|
|
|
|
|
workspace_size_limit));
|
|
|
|
|
};
|
|
|
|
|
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "Perf result: (algo: stat, time, memory)";
|
|
|
|
|
for (int i = 0; i < returned_algo_count; ++i) {
|
|
|
|
@ -225,23 +219,17 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
|
|
|
|
|
"workspace_size to be allocated exceeds the limit");
|
|
|
|
|
|
|
|
|
|
// Allocate on GPU memory
|
|
|
|
|
if (!cudnn_workspace_ptr) {
|
|
|
|
|
cudnn_workspace =
|
|
|
|
|
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(workspace_size_in_bytes)}),
|
|
|
|
|
dev_ctx);
|
|
|
|
|
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
|
|
|
|
|
}
|
|
|
|
|
// ------------------- cudnn conv forward ---------------------
|
|
|
|
|
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
|
|
|
|
|
for (int i = 0; i < groups; i++) {
|
|
|
|
|
auto cudnn_func = [&](void* cudnn_workspace) {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
|
|
|
|
|
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
|
|
|
|
|
cudnn_filter_desc, filter_data + i * group_offset_filter,
|
|
|
|
|
cudnn_conv_desc, algo, cudnn_workspace_ptr, workspace_size_in_bytes,
|
|
|
|
|
cudnn_conv_desc, algo, cudnn_workspace, workspace_size_in_bytes,
|
|
|
|
|
&beta, cudnn_output_desc, output_data + i * group_offset_out));
|
|
|
|
|
};
|
|
|
|
|
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -365,20 +353,10 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
workspace_size_limit = max_user_size * 1024 * 1024;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Tensor cudnn_workspace;
|
|
|
|
|
void* cudnn_workspace_ptr = nullptr;
|
|
|
|
|
if ((input_data || filter_data) && exhaustive_search) {
|
|
|
|
|
cudnn_workspace =
|
|
|
|
|
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(workspace_size_limit)}),
|
|
|
|
|
dev_ctx);
|
|
|
|
|
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto x_dims = framework::vectorize(input->dims());
|
|
|
|
|
auto f_dims = framework::vectorize(filter->dims());
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
|
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
|
|
|
|
|
if (input_grad) {
|
|
|
|
|
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
if (exhaustive_search) {
|
|
|
|
@ -396,22 +374,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
->GetMutable<
|
|
|
|
|
AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
data_algo = data_algo_cache->GetAlgorithm(
|
|
|
|
|
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|
std::array<cudnnConvolutionBwdDataAlgoPerf_t,
|
|
|
|
|
kNUM_CUDNN_BWD_DATA_ALGS>
|
|
|
|
|
data_perf_stat;
|
|
|
|
|
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::
|
|
|
|
|
auto cudnn_find_bd_data_func = [&](void* cudnn_workspace) {
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::
|
|
|
|
|
cudnnFindConvolutionBackwardDataAlgorithmEx(
|
|
|
|
|
handle, cudnn_filter_desc, filter_data,
|
|
|
|
|
cudnn_output_grad_desc, output_grad_data,
|
|
|
|
|
cudnn_conv_desc, cudnn_input_desc,
|
|
|
|
|
input_grad_data, kNUM_CUDNN_BWD_DATA_ALGS,
|
|
|
|
|
&returned_algo_count, data_perf_stat.data(),
|
|
|
|
|
cudnn_workspace_ptr, workspace_size_limit));
|
|
|
|
|
cudnn_conv_desc, cudnn_input_desc, input_grad_data,
|
|
|
|
|
kNUM_CUDNN_BWD_DATA_ALGS, &returned_algo_count,
|
|
|
|
|
data_perf_stat.data(), cudnn_workspace,
|
|
|
|
|
workspace_size_limit));
|
|
|
|
|
};
|
|
|
|
|
workspace_handle.RunFunc(cudnn_find_bd_data_func,
|
|
|
|
|
workspace_size_limit);
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "Perf result: (algo: stat, time, memory)";
|
|
|
|
|
for (int i = 0; i < returned_algo_count; ++i) {
|
|
|
|
@ -462,23 +443,25 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
->GetMutable<
|
|
|
|
|
AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
filter_algo = f_algo_cache->GetAlgorithm(
|
|
|
|
|
x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
|
|
|
|
|
int returned_algo_count;
|
|
|
|
|
std::array<cudnnConvolutionBwdFilterAlgoPerf_t,
|
|
|
|
|
kNUM_CUDNN_BWD_FILTER_ALGS>
|
|
|
|
|
filter_perf_stat;
|
|
|
|
|
|
|
|
|
|
auto cudnn_find_bd_f_func = [&](void* cudnn_workspace) {
|
|
|
|
|
CUDNN_ENFORCE(
|
|
|
|
|
platform::dynload::
|
|
|
|
|
cudnnFindConvolutionBackwardFilterAlgorithmEx(
|
|
|
|
|
handle, cudnn_input_desc, input_data,
|
|
|
|
|
cudnn_output_grad_desc, output_grad_data,
|
|
|
|
|
cudnn_conv_desc, cudnn_filter_desc, filter_grad_data,
|
|
|
|
|
kNUM_CUDNN_BWD_FILTER_ALGS, &returned_algo_count,
|
|
|
|
|
filter_perf_stat.data(), cudnn_workspace_ptr,
|
|
|
|
|
workspace_size_limit));
|
|
|
|
|
cudnn_conv_desc, cudnn_filter_desc,
|
|
|
|
|
filter_grad_data, kNUM_CUDNN_BWD_FILTER_ALGS,
|
|
|
|
|
&returned_algo_count, filter_perf_stat.data(),
|
|
|
|
|
cudnn_workspace, workspace_size_limit));
|
|
|
|
|
};
|
|
|
|
|
workspace_handle.RunFunc(cudnn_find_bd_f_func,
|
|
|
|
|
workspace_size_limit);
|
|
|
|
|
return filter_perf_stat[0].algo;
|
|
|
|
|
});
|
|
|
|
|
VLOG(3) << "cuDNN backward filter algo " << filter_algo;
|
|
|
|
@ -499,16 +482,6 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
workspace_size_in_bytes = std::max(workspace_size_in_bytes, tmp_size);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn conv workspace ---------------------
|
|
|
|
|
if (!cudnn_workspace_ptr) {
|
|
|
|
|
cudnn_workspace =
|
|
|
|
|
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
|
|
|
|
|
framework::make_ddim(
|
|
|
|
|
{static_cast<int64_t>(workspace_size_in_bytes)}),
|
|
|
|
|
dev_ctx);
|
|
|
|
|
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn conv backward data ---------------------
|
|
|
|
|
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
|
|
|
|
|
if (input_grad) {
|
|
|
|
@ -516,12 +489,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Because beta is zero, it is unnecessary to reset input_grad.
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < groups; i++) {
|
|
|
|
|
auto cudnn_func = [&](void* cudnn_workspace) {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
|
|
|
|
|
handle, &alpha, cudnn_filter_desc,
|
|
|
|
|
filter_data + i * group_offset_filter, cudnn_output_grad_desc,
|
|
|
|
|
output_grad_data + i * group_offset_out, cudnn_conv_desc, data_algo,
|
|
|
|
|
cudnn_workspace_ptr, workspace_size_in_bytes, &beta,
|
|
|
|
|
output_grad_data + i * group_offset_out, cudnn_conv_desc,
|
|
|
|
|
data_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
|
|
|
|
|
cudnn_input_desc, input_grad_data + i * group_offset_in));
|
|
|
|
|
};
|
|
|
|
|
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// ------------------- cudnn conv backward filter ---------------------
|
|
|
|
@ -529,12 +505,15 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
T* filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
// Because beta is zero, it is unnecessary to reset filter_grad.
|
|
|
|
|
for (int i = 0; i < groups; i++) {
|
|
|
|
|
auto cudnn_func = [&](void* cudnn_workspace) {
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
|
|
|
|
|
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
|
|
|
|
|
cudnn_output_grad_desc, output_grad_data + i * group_offset_out,
|
|
|
|
|
cudnn_conv_desc, filter_algo, cudnn_workspace_ptr,
|
|
|
|
|
workspace_size_in_bytes, &beta, cudnn_filter_desc,
|
|
|
|
|
filter_grad_data + i * group_offset_filter));
|
|
|
|
|
handle, &alpha, cudnn_input_desc,
|
|
|
|
|
input_data + i * group_offset_in, cudnn_output_grad_desc,
|
|
|
|
|
output_grad_data + i * group_offset_out, cudnn_conv_desc,
|
|
|
|
|
filter_algo, cudnn_workspace, workspace_size_in_bytes, &beta,
|
|
|
|
|
cudnn_filter_desc, filter_grad_data + i * group_offset_filter));
|
|
|
|
|
};
|
|
|
|
|
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|