|
|
|
@ -20,6 +20,11 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/platform/cudnn_helper.h"
|
|
|
|
|
#include "paddle/fluid/platform/float16.h"
|
|
|
|
|
|
|
|
|
|
DEFINE_bool(cudnn_algo_use_autotune, true,
|
|
|
|
|
"Whether allow using an autotuning algorithm for convolution "
|
|
|
|
|
"operator. The autotuning algorithm may be non-deterministic. If "
|
|
|
|
|
"false, the algorithm is deterministic.");
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -267,17 +272,23 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
|
if (input_grad) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
|
handle, cudnn_filter_desc,
|
|
|
|
|
// dyDesc: Handle to the previously initialized input differential
|
|
|
|
|
// tensor descriptor.
|
|
|
|
|
cudnn_output_grad_desc, cudnn_conv_desc,
|
|
|
|
|
// dxDesc: Handle to the previously initialized output tensor
|
|
|
|
|
// descriptor.
|
|
|
|
|
cudnn_input_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &data_algo));
|
|
|
|
|
if (FLAGS_cudnn_algo_use_autotune) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
|
|
|
|
|
handle, cudnn_filter_desc,
|
|
|
|
|
// dyDesc: Handle to the previously initialized input
|
|
|
|
|
// differential
|
|
|
|
|
// tensor descriptor.
|
|
|
|
|
cudnn_output_grad_desc, cudnn_conv_desc,
|
|
|
|
|
// dxDesc: Handle to the previously initialized output tensor
|
|
|
|
|
// descriptor.
|
|
|
|
|
cudnn_input_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &data_algo));
|
|
|
|
|
} else {
|
|
|
|
|
data_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
|
|
|
handle, cudnn_filter_desc, cudnn_output_grad_desc,
|
|
|
|
@ -286,12 +297,16 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (filter_grad) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
|
handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
|
|
|
|
|
cudnn_filter_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &filter_algo));
|
|
|
|
|
if (FLAGS_cudnn_algo_use_autotune) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
|
|
|
|
|
handle, cudnn_input_desc, cudnn_output_grad_desc,
|
|
|
|
|
cudnn_conv_desc, cudnn_filter_desc,
|
|
|
|
|
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
|
|
|
|
|
workspace_size_limit, &filter_algo));
|
|
|
|
|
} else {
|
|
|
|
|
filter_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
platform::dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
|
|
|