@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_int64(cudnn_exhaustive_search_times);
DECLARE_int64(cudnn_exhaustive_search_times);
@ -44,6 +47,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NOT_NULL(bias, "The bias should not be null.");
PADDLE_ENFORCE_NOT_NULL(bias, "The bias should not be null.");
auto* residual = ctx.Input<Tensor>("ResidualData");
auto* residual = ctx.Input<Tensor>("ResidualData");
auto* output = ctx.Output<Tensor>("Output");
auto* output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
@ -55,11 +59,96 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
bool exhaustive_search =
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
const T* input_data = input->data<T>();
// const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
const T* filter_data = filter->data<T>();
const T* bias_data = bias->data<T>();
const T* bias_data = bias->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// T* output_data = output->mutable_data<T>(ctx.GetPlace());
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
const std::string data_format = ctx.Attr<std::string>("data_format");
Tensor transformed_input_channel(input->type());
Tensor transformed_output(output->type());
T* output_data = nullptr;
transformed_input_channel = *input;
transformed_output = *output;
output_data = transformed_output.data<T>();
const T* residual_data = residual ? residual->data<T>() : output_data;
const T* residual_data = residual ? residual->data<T>() : output_data;
// update padding and dilation
auto in_dims = transformed_input_channel.dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims;
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
Tensor transformed_input;
std::vector<int> padding_common(data_dim, 0);
if (!is_sys_pad) {
std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_input_channel.dims()[0];
new_input_shape_vec[1] = transformed_input_channel.dims()[1];
std::vector<int> input_pad(transformed_input_channel.dims().size() * 2,
0);
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] =
transformed_input_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
framework::DDim new_input_shape(
framework::make_ddim(new_input_shape_vec));
transformed_input.Resize(new_input_shape);
auto& dev_ctx =
ctx.template device_context<paddle::platform::CUDADeviceContext>();
transformed_input =
ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
new_input_shape, dev_ctx);
const int rank = transformed_input_channel.dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
case 5: {
math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
ctx, input_pad, transformed_input_channel, pad_value,
&transformed_input);
} break;
default:
PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
}
} else {
transformed_input = transformed_input_channel;
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
const T* input_data = transformed_input.data<T>();
// ------------------- cudnn descriptors ---------------------
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor input_desc;
@ -74,18 +163,19 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
}
}
cudnnConvolutionDescriptor_t cudnn_conv_desc =
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings , strides, dilations);
conv_desc.descriptor<T>(padding_common , strides, dilations);
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
cudnn_conv_desc, groups));
cudnn_conv_desc, groups));
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(input-> dims()));
layout, framework::vectorize<int>(transformed_input. dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, framework::vectorize<int>(output-> dims()));
layout, framework::vectorize<int>(transformed_output. dims()));
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
layout, framework::vectorize<int>(filter->dims()));
layout, framework::vectorize<int>(filter->dims()));
// Now only support NCHW
// Now only support NCHW
std::vector<int> bias_dim = {1, static_cast<int>(output->dims()[1]), 1, 1};
std::vector<int> bias_dim = {
1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
cudnnTensorDescriptor_t cudnn_bias_desc =
cudnnTensorDescriptor_t cudnn_bias_desc =
bias_desc.descriptor<T>(layout, bias_dim);
bias_desc.descriptor<T>(layout, bias_dim);
cudnnActivationDescriptor_t cudnn_act_desc =
cudnnActivationDescriptor_t cudnn_act_desc =
@ -109,7 +199,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
auto x_dims = framework::vectorize(input-> dims());
auto x_dims = framework::vectorize(transformed_input. dims());
auto f_dims = framework::vectorize(filter->dims());
auto f_dims = framework::vectorize(filter->dims());
if (!exhaustive_search) {
if (!exhaustive_search) {
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(