From 00efd1d8a91d976224733035c6669adf8933be66 Mon Sep 17 00:00:00 2001
From: chengjuntao <806512756@qq.com>
Date: Tue, 17 Sep 2019 19:52:21 +0800
Subject: [PATCH] add deformable conv v1 op and cpu version of deformable conv
 v2 (#18500)

* add deformable conv v1 op, test=develop
---
 cmake/operators.cmake                         |    2 +-
 paddle/fluid/API.spec                         |    2 +-
 paddle/fluid/operators/CMakeLists.txt         |    4 +-
 .../operators/deformable_conv_filter.cu.h     |   37 ++
 paddle/fluid/operators/deformable_conv_func.h |  149 +++++
 paddle/fluid/operators/deformable_conv_op.cc  |   10 +-
 paddle/fluid/operators/deformable_conv_op.cu  |    1 +
 paddle/fluid/operators/deformable_conv_op.h   |  613 ++++++++++++++++++
 .../fluid/operators/deformable_conv_v1_op.cc  |  272 ++++++++
 .../fluid/operators/deformable_conv_v1_op.cu  |  609 +++++++++++++++++
 .../fluid/operators/deformable_conv_v1_op.h   |  564 ++++++++++++++++
 python/paddle/fluid/layers/nn.py              |   87 ++-
 .../unittests/test_deformable_conv_op.py      |   53 +-
 .../unittests/test_deformable_conv_v1_op.py   |  240 +++++++
 .../fluid/tests/unittests/test_layers.py      |   74 ++-
 15 files changed, 2627 insertions(+), 90 deletions(-)
 mode change 100755 => 100644 paddle/fluid/API.spec
 create mode 100644 paddle/fluid/operators/deformable_conv_filter.cu.h
 create mode 100644 paddle/fluid/operators/deformable_conv_func.h
 create mode 100644 paddle/fluid/operators/deformable_conv_op.h
 create mode 100644 paddle/fluid/operators/deformable_conv_v1_op.cc
 create mode 100644 paddle/fluid/operators/deformable_conv_v1_op.cu
 create mode 100644 paddle/fluid/operators/deformable_conv_v1_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py

diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index f43d284ad0..92dc614e09 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -110,7 +110,7 @@ function(op_library TARGET)
     # Define operators that don't need pybind here.
     foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
 "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
-"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op")
+"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
       if ("${TARGET}" STREQUAL "${manual_pybind_op}")
         set(pybind_flag 1)
       endif()

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
old mode 100755
new mode 100644
index a0a59eefc5..5252ca5d10
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -285,7 +285,7 @@ paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=
 paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'c03490ffaa1b78258747157c313db4cd'))
 paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', 'b1e1487760295e1ff55307b880a99e18'))
 paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'fa2f457a81714430c5677c2d68744728'))
-paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e'))
+paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'modulated', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, True, None)), ('document', '335193ac57d41d7199f8d26d30c069b1'))
 paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
 paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
 paddle.fluid.layers.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', 'b6ea7d4ddeacae85e37d1e47d5262948'))

diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index d2d0f6248e..f99cbc8762 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -55,7 +55,7 @@ if (NOT WITH_MKL)
 endif()

 register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
-    sync_batch_norm_op deformable_conv_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
+    sync_batch_norm_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})

 if (WITH_GPU)
     # warpctc_op needs cudnn 7 above
@@ -73,8 +73,6 @@ if (WITH_GPU)
         op_library(sync_batch_norm_op)
         file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
     endif()
-    op_library(deformable_conv_op)
-    file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(deformable_conv);\n")
 else()
    op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 endif()

diff --git a/paddle/fluid/operators/deformable_conv_filter.cu.h b/paddle/fluid/operators/deformable_conv_filter.cu.h
new file mode 100644
index 0000000000..f466d1803f
--- /dev/null
+++ b/paddle/fluid/operators/deformable_conv_filter.cu.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// Part of the following code in this file refers to
+// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
+//
+// Copyright (c) 2017 Microsoft
+// Licensed under The Apache-2.0 License [see LICENSE for details]
+// \file deformable_psroi_pooling.cu
+// \brief
+// \author Yi Li, Guodong Zhang, Jifeng Dai
+
+#pragma once
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+// Accumulates the per-step filter gradient dweight_3d into filter_grad,
+// using a grid-stride loop over all filter elements.
+template <typename T>
+__global__ void FilterGradAddupCUDAKernel(const int nthreads, const int n,
+                                          const int height, const int width,
+                                          const T* dweight_3d,
+                                          T* filter_grad) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  for (size_t i = index; i < nthreads; i += offset) {
+    filter_grad[i] = filter_grad[i] + dweight_3d[i];
+  }
+}

diff --git a/paddle/fluid/operators/deformable_conv_func.h b/paddle/fluid/operators/deformable_conv_func.h
new file mode 100644
index 0000000000..ba1c504430
--- /dev/null
+++ b/paddle/fluid/operators/deformable_conv_func.h
@@ -0,0 +1,149 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// +// Part of the following code in this file refs to +// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu +// +// Copyright (c) 2017 Microsoft +// Licensed under The Apache-2.0 License [see LICENSE for details] +// \file deformable_psroi_pooling.cu +// \brief +// \author Yi Li, Guodong Zhang, Jifeng Dai + +#pragma once +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/hostdevice.h" + +template +HOSTDEVICE T DmcnGetGradientWeight(T argmax_h, T argmax_w, const int h, + const int w, const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + weight = (h == argmax_h_low && w == argmax_w_low) + ? (h + 1 - argmax_h) * (w + 1 - argmax_w) + : weight; + weight = (h == argmax_h_low && w == argmax_w_high) + ? (h + 1 - argmax_h) * (argmax_w + 1 - w) + : weight; + weight = (h == argmax_h_high && w == argmax_w_low) + ? (argmax_h + 1 - h) * (w + 1 - argmax_w) + : weight; + weight = (h == argmax_h_high && w == argmax_w_high) + ? (argmax_h + 1 - h) * (argmax_w + 1 - w) + : weight; + + return weight; +} + +template +HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height, + const int width, const T* im_data, + const int data_width, const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + T weight = 0; + + if (bp_dir == 0) { + weight += (argmax_h_low >= 0 && argmax_w_low >= 0) + ? -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low] + : 0; + + weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1) + ? -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high] + : 0; + + weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0) + ? (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low] + : 0; + weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + ? (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high] + : 0; + } else if (bp_dir == 1) { + weight += (argmax_h_low >= 0 && argmax_w_low >= 0) + ? -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low] + : 0; + weight += (argmax_h_low >= 0 && argmax_w_high <= width - 1) + ? (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high] + : 0; + weight += (argmax_h_high <= height - 1 && argmax_w_low >= 0) + ? -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low] + : 0; + weight += (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + ? 
(argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high] + : 0; + } + + return weight; +} + +template +HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data, const int data_width, + const int height, const int width, T h, T w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh; + T hw = 1 - lw; + + T v1 = + (h_low >= 0 && w_low >= 0) ? bottom_data[h_low * data_width + w_low] : 0; + T v2 = (h_low >= 0 && w_high <= width - 1) + ? bottom_data[h_low * data_width + w_high] + : 0; + T v3 = (h_high <= height - 1 && w_low >= 0) + ? bottom_data[h_high * data_width + w_low] + : 0; + T v4 = (h_high <= height - 1 && w_high <= width - 1) + ? bottom_data[h_high * data_width + w_high] + : 0; + + T w1 = hh * hw; + T w2 = hh * lw; + T w3 = lh * hw; + T w4 = lh * lw; + + return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; +} diff --git a/paddle/fluid/operators/deformable_conv_op.cc b/paddle/fluid/operators/deformable_conv_op.cc index 92a93dc747..01cbec5633 100644 --- a/paddle/fluid/operators/deformable_conv_op.cc +++ b/paddle/fluid/operators/deformable_conv_op.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/operators/deformable_conv_op.h" +#include #include "paddle/fluid/operators/conv_op.h" namespace paddle { @@ -197,7 +199,6 @@ class DeformableConvOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]), deformable_groups, "mask filter must divide deformable group size."); - ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); } @@ -274,5 +275,10 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp, ops::DeformableConvOpMaker, ops::DeformableConvGradOpDescMaker); - REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp); + +REGISTER_OP_CPU_KERNEL(deformable_conv, ops::DeformableConvCPUKernel, + ops::DeformableConvCPUKernel); +REGISTER_OP_CPU_KERNEL(deformable_conv_grad, + ops::DeformableConvGradCPUKernel, + ops::DeformableConvGradCPUKernel); diff --git a/paddle/fluid/operators/deformable_conv_op.cu b/paddle/fluid/operators/deformable_conv_op.cu index cbb9bed90c..0a771627e0 100644 --- a/paddle/fluid/operators/deformable_conv_op.cu +++ b/paddle/fluid/operators/deformable_conv_op.cu @@ -24,6 +24,7 @@ #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/deformable_conv_op.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cuda_primitives.h" diff --git a/paddle/fluid/operators/deformable_conv_op.h b/paddle/fluid/operators/deformable_conv_op.h new file mode 100644 index 0000000000..33a97bf48b --- /dev/null +++ b/paddle/fluid/operators/deformable_conv_op.h @@ -0,0 +1,613 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// Part of the following code in this file refs to +// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu +// +// Copyright (c) 2017 Microsoft +// Licensed under The Apache-2.0 License [see LICENSE for details] +// \file deformable_psroi_pooling.cu +// \brief +// \author Yi Li, Guodong Zhang, Jifeng Dai + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/deformable_conv_func.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using CPUDeviceContext = platform::CPUDeviceContext; + +template +void ModulatedDeformableCol2imCPUKernel( + const int num_kernels, const T* data_col, const T* data_offset, + const T* data_mask, const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int deformable_group, const int height_col, + const int width_col, T* grad_im) { + for (size_t thread = 0; thread < num_kernels; thread++) { + const int j = (thread / width_col / height_col / batch_size) % kernel_w; + const int i = + (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + thread / width_col / height_col / batch_size / kernel_w / kernel_h; + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = thread % width_col; + int h_out = (thread / width_col) % height_col; + int b = (thread / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const T* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[thread] * mask; + const int cur_h = static_cast(cur_inv_h_data); + const int cur_w = static_cast(cur_inv_w_data); + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = + DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, + cur_w + dx, height, width); + + *(grad_im + cur_bottom_grad_pos) = + *(grad_im + cur_bottom_grad_pos) + weight * 
cur_top_grad; + } + } + } + } +} + +template +static inline void ModulatedDeformableCol2imCPU( + const platform::CPUDeviceContext& ctx, const T* data_col, + const T* data_offset, const T* data_mask, + const std::vector im_shape, const std::vector col_shape, + const std::vector kernel_shape, const std::vector pad, + const std::vector stride, const std::vector dilation, + const int deformable_group, T* grad_im) { + int channel_per_deformable_group = im_shape[0] / deformable_group; + int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + + ModulatedDeformableCol2imCPUKernel( + num_kernels, data_col, data_offset, data_mask, im_shape[0], im_shape[1], + im_shape[2], kernel_shape[2], kernel_shape[3], pad[0], pad[1], stride[0], + stride[1], dilation[0], dilation[1], channel_per_deformable_group, + col_shape[1], deformable_group, col_shape[2], col_shape[3], grad_im); +} + +template +void ModulatedDeformableCol2imCoordCPUKernel( + const int num_kernels, const T* data_col, const T* data_im, + const T* data_offset, const T* data_mask, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int offset_channels, const int deformable_group, const int height_col, + const int width_col, T* grad_offset, T* grad_mask) { + for (size_t i = 0; i < num_kernels; i++) { + T val = 0, mval = 0; + const int w = i % width_col; + const int h = (i / width_col) % height_col; + const int c = (i / width_col / height_col) % offset_channels; + const int b = (i / width_col / height_col) / offset_channels; + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T* data_col_ptr = data_col + + deformable_group_index * + channel_per_deformable_group * batch_size * + width_col * height_col; + const T* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / + kernel_w * height * width; + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const T* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const int data_mask_hw_ptr = + (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = 
data_mask_ptr[data_mask_hw_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } else { + mval += data_col_ptr[col_pos] * + DmcnIm2colBilinear(data_im_ptr + cnt * height * width, width, + height, width, inv_h, inv_w); + } + const T weight = DmcnGetCoordinateWeight( + inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + grad_offset[i] = val; + if (offset_c % 2 == 0) + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * + kernel_w + + offset_c / 2) * + height_col + + h) * + width_col + + w] = mval; + } +} + +template +static inline void ModulatedDeformableCol2imCoordCPU( + const platform::CPUDeviceContext& ctx, const T* data_col, const T* data_im, + const T* data_offset, const T* data_mask, + const std::vector im_shape, const std::vector col_shape, + const std::vector kernel_shape, const std::vector paddings, + const std::vector strides, const std::vector dilations, + const int deformable_groups, T* grad_offset, T* grad_mask) { + int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int channel_per_deformable_group = col_shape[0] / deformable_groups; + + ModulatedDeformableCol2imCoordCPUKernel( + num_kernels, data_col, data_im, data_offset, data_mask, im_shape[0], + im_shape[1], im_shape[2], kernel_shape[2], kernel_shape[3], paddings[0], + paddings[1], strides[0], strides[1], dilations[0], dilations[1], + channel_per_deformable_group, col_shape[1], + 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, + deformable_groups, col_shape[2], col_shape[3], grad_offset, grad_mask); +} + +template +void ModulatedDeformableIm2colCPUKernel( + const int num_kernels, const T* data_im, const T* data_offset, + const T* data_mask, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T* data_col) { + for (size_t i = 0; i < num_kernels; i++) { + const int w_col = i % width_col; + const int h_col = (i / width_col) % height_col; + const int b_col = (i / width_col) / height_col % batch_size; + const int c_im = (i / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T* data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T* data_offset_ptr = + data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const T* data_mask_ptr = + data_mask + + (b_col * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * 
kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T mask = data_mask_ptr[data_mask_hw_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + val = + DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +static inline void ModulatedDeformableIm2colCPU( + const platform::CPUDeviceContext& ctx, const T* data_im, + const T* data_offset, const T* data_mask, + const std::vector im_shape, const std::vector col_shape, + const std::vector filter_shape, const std::vector paddings, + const std::vector strides, const std::vector dilations, + const int deformable_groups, T* data_col) { + int channel_per_deformable_group = im_shape[0] / deformable_groups; + int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + + // get outputs of im2col with offset by bilinear interpolation + ModulatedDeformableIm2colCPUKernel( + num_kernels, data_im, data_offset, data_mask, im_shape[1], im_shape[2], + filter_shape[2], filter_shape[3], paddings[0], paddings[1], strides[0], + strides[1], dilations[0], dilations[1], channel_per_deformable_group, + col_shape[1], im_shape[0], deformable_groups, col_shape[2], col_shape[3], + data_col); +} + +template +void FilterGradAddupCPUKernel(const int nthreads, const int n, const int height, + const int width, const T* dweight_3d, + T* filter_grad) { + for (size_t i = 0; i < nthreads; i++) { + filter_grad[i] = filter_grad[i] + dweight_3d[i]; + } +} + +template +class DeformableConvCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* offset = ctx.Input("Offset"); + auto* mask = ctx.Input("Mask"); + Tensor filter = *ctx.Input("Filter"); + Tensor* output = ctx.Output("Output"); + output->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + const int groups = ctx.Attr("groups"); + const int deformable_groups = ctx.Attr("deformable_groups"); + const int im2col_step = ctx.Attr("im2col_step"); + const std::vector strides = ctx.Attr>("strides"); + const std::vector paddings = ctx.Attr>("paddings"); + const std::vector dilations = ctx.Attr>("dilations"); + + const int batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec(framework::vectorize(output->dims())); + + // col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w} + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = + input->dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_buffer_shape_vec)); + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + framework::DDim output_shape(framework::make_ddim(output_buffer_shape_vec)); + Tensor 
col_buffer; + Tensor output_buffer; + col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); + output_buffer = + ctx.AllocateTmpTensor(output_shape, dev_ctx); + int64_t M = output_shape_vec[1] / groups; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = + input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups; + + Tensor weight_3d; + weight_3d.ShareDataWith(filter).Resize( + framework::make_ddim({groups, M, K})); + Tensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer) + .Resize(framework::make_ddim({groups, K, N})); + Tensor output_4d; + output_4d.ShareDataWith(output_buffer) + .Resize(framework::make_ddim({batch_size / im2col_step, groups, M, N})); + output_4d.mutable_data(ctx.GetPlace()); + framework::DDim input_shape = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + std::vector input_shape_vec = framework::vectorize(input_shape); + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset->numel() / offset->dims()[0]; + int input_mask_dim = mask->numel() / mask->dims()[0]; + auto blas = math::GetBlas(dev_ctx); + const T* input_ptr = input->data(); + const T* offset_ptr = offset->data(); + const T* mask_ptr = mask->data(); + col_buffer.mutable_data(ctx.GetPlace()); + T* col_buffer_ptr = col_buffer.data(); + for (int i = 0; i < batch_size / im2col_step; ++i) { + ModulatedDeformableIm2colCPU( + dev_ctx, input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations, + deformable_groups, col_buffer_ptr); + Tensor output_3d = output_4d.Slice(i, i + 1).Resize( + framework::slice_ddim(output_4d.dims(), 1, output_4d.dims().size())); + // get the product of pixel and weight + for (int g = 0; g < groups; ++g) { + Tensor weight_3d_slice = + weight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + weight_3d.dims(), 1, weight_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + Tensor output_3d_slice = + output_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + output_3d.dims(), 1, output_3d.dims().size())); + blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0), + &output_3d_slice, T(0.0)); + } + } + output->ShareDataWith(output_buffer) + .Resize(framework::make_ddim(output_shape_vec)); + } +}; + +template +class DeformableConvGradCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* output_grad = + ctx.Input(framework::GradVarName("Output")); + Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); + Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); + Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); + Tensor* mask_grad = ctx.Output(framework::GradVarName("Mask")); + + const Tensor* input = ctx.Input("Input"); + Tensor offset = *ctx.Input("Offset"); + Tensor mask = *ctx.Input("Mask"); + Tensor filter = *ctx.Input("Filter"); + if (!input_grad && !filter_grad && !offset_grad && !mask_grad) return; + + int groups = ctx.Attr("groups"); + int deformable_groups = ctx.Attr("deformable_groups"); + int im2col_step = ctx.Attr("im2col_step"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = 
ctx.Attr>("dilations"); + + auto& dev_ctx = ctx.template device_context(); + const int batch_size = static_cast(input->dims()[0]); + + framework::DDim input_shape = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + std::vector input_shape_vec = framework::vectorize(input_shape); + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec( + framework::vectorize(output_grad->dims())); + + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = + input->dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_buffer_shape_vec)); + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + framework::DDim output_shape(framework::make_ddim(output_buffer_shape_vec)); + Tensor col_buffer; + Tensor output_buffer; + col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); + output_buffer = + ctx.AllocateTmpTensor(output_shape, dev_ctx); + + output_buffer.ShareDataWith(*output_grad); + + int64_t M = + input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = output_shape_vec[1] / groups; + + framework::DDim weight_3d_shape = {groups, K, M}; + framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, + N}; + framework::DDim col_buffer_3d_shape = {groups, M, N}; + framework::DDim filter_grad_shape = {groups, K, M}; + + Tensor weight_3d; + weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); + Tensor out_grad_4d; + out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); + Tensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); + + math::SetConstant set_zero; + auto blas = math::GetBlas(dev_ctx); + + col_buffer.mutable_data(ctx.GetPlace()); + col_buffer_3d.mutable_data(ctx.GetPlace()); + out_grad_4d.mutable_data(ctx.GetPlace()); + + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset.numel() / offset.dims()[0]; + int input_mask_dim = mask.numel() / mask.dims()[0]; + + if (filter_grad) { + filter_grad->mutable_data(ctx.GetPlace()); + filter_grad->Resize(filter_grad_shape); + set_zero(dev_ctx, filter_grad, static_cast(0)); + } + + if (input_grad) { + input_grad->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, input_grad, static_cast(0)); + } + + if (offset_grad && mask_grad) { + offset_grad->mutable_data(ctx.GetPlace()); + mask_grad->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, offset_grad, static_cast(0)); + set_zero(dev_ctx, mask_grad, static_cast(0)); + } + + for (int i = 0; i < batch_size / im2col_step; ++i) { + Tensor out_grad_3d = + out_grad_4d.Slice(i, i + 1).Resize(framework::slice_ddim( + out_grad_4d.dims(), 1, out_grad_4d.dims().size())); + for (int g = 0; g < groups; ++g) { + Tensor weight_3d_slice = + weight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + weight_3d.dims(), 1, weight_3d.dims().size())); + Tensor out_grad_3d_slice = + out_grad_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + out_grad_3d.dims(), 1, out_grad_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + + 
blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0), + &col_buffer_3d_slice, T(0.0)); + } + col_buffer.Resize(col_shape); + + T* col_buffer_ptr = col_buffer.data(); + const T* input_ptr = input->data(); + const T* offset_ptr = offset.data(); + const T* mask_ptr = mask.data(); + + if (mask_grad && offset_grad) { + T* offset_grad_ptr = offset_grad->data(); + T* mask_grad_ptr = mask_grad->data(); + // get grad of offset and mask + ModulatedDeformableCol2imCoordCPU( + ctx.template device_context(), col_buffer_ptr, + input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, + dilations, deformable_groups, + offset_grad_ptr + i * im2col_step * input_offset_dim, + mask_grad_ptr + i * im2col_step * input_mask_dim); + } + if (input_grad) { + T* input_grad_ptr = input_grad->data(); + // get grad of input + ModulatedDeformableCol2imCPU( + ctx.template device_context(), col_buffer_ptr, + offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, + dilations, deformable_groups, + input_grad_ptr + i * im2col_step * input_dim); + input_grad->Resize(input->dims()); + } + + ModulatedDeformableIm2colCPU( + ctx.template device_context(), + input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_ptr + i * im2col_step * input_mask_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, dilations, + deformable_groups, col_buffer_ptr); + + col_buffer_3d.Resize(col_buffer_3d_shape); + + if (filter_grad) { + Tensor dweight_3d; + dweight_3d = ctx.AllocateTmpTensor( + filter_grad_shape, dev_ctx); + for (int g = 0; g < groups; ++g) { + Tensor out_grad_3d_slice = + out_grad_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + out_grad_3d.dims(), 1, out_grad_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + Tensor dweight_3d_slice = + dweight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + dweight_3d.dims(), 1, dweight_3d.dims().size())); + + blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true, + T(1.0), &dweight_3d_slice, T(0.0)); + } + // update grad of weights + FilterGradAddupCPUKernel(dweight_3d.numel(), groups, K, M, + dweight_3d.data(), filter_grad->data()); + } + } + if (filter_grad) { + filter_grad->Resize(filter.dims()); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cc b/paddle/fluid/operators/deformable_conv_v1_op.cc new file mode 100644 index 0000000000..6129e29655 --- /dev/null +++ b/paddle/fluid/operators/deformable_conv_v1_op.cc @@ -0,0 +1,272 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/deformable_conv_v1_op.h"
+#include <memory>
+#include "paddle/fluid/operators/conv_op.h"
+
+namespace paddle {
+namespace operators {
+class DeformableConvV1OpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input",
+             "(Tensor) The input of deformable conv op. "
+             "The shape of the input is "
+             "[N, channel_in, H, W].");
+    AddInput("Offset",
+             "(Tensor) The input offset. "
+             "The shape of the offset is "
+             "[N, deformable_groups * kernel_w * kernel_h * 2, H, W].");
+    AddInput("Filter",
+             "(Tensor) The input filter. "
+             "The shape of the filter is "
+             "[num_filters, channel_in, kernel_h, kernel_w].");
+    AddOutput("Output",
+              "(Tensor) The output. "
+              "The shape of the output tensor is "
+              "[N, num_filters, out_height, out_width].");
+    AddAttr<std::vector<int>>("strides",
+                              "(vector<int> default:{1, 1}), the "
+                              "strides(h_stride, w_stride) of "
+                              "convolution operator.")
+        .SetDefault({1, 1});
+    AddAttr<std::vector<int>>("paddings",
+                              "(vector<int> default:{0, 0}), the "
+                              "paddings(h_pad, w_pad) of "
+                              "convolution operator.")
+        .SetDefault({0, 0});
+    AddAttr<std::vector<int>>("dilations",
+                              "(vector<int> default:{1, 1}), the "
+                              "dilations(h_dilation, w_dilation) of "
+                              "convolution operator.")
+        .SetDefault({1, 1});
+    AddAttr<int>(
+        "groups",
+        "(int default:1), the group number of the convolution operator. "
+        "According to grouped convolution in Alex Krizhevsky's Deep CNN "
+        "paper: when group=2, the first half of the filters is only "
+        "connected to the first half of the input channels, while the "
+        "second half of the filters is only connected to the second half "
+        "of the input channels.")
+        .SetDefault(1);
+    AddAttr<int>("deformable_groups",
+                 "(int default:1), the number of the deformable groups.")
+        .SetDefault(1);
+    AddAttr<int>("im2col_step",
+                 "(int default:64), the maximum number of images processed "
+                 "per im2col computation.")
+        .SetDefault(64);
+    AddComment(R"DOC(
+**Deformable Convolution v1 Operator**
+
+Deformable Convolution is a variant of convolution in which each sampling
+location of the kernel is shifted by a learned spatial offset.
+
+1. Compute the offset of each sampling location with an extra convolution
+   layer whose number of output channels is twice the kernel size
+   (one offset per spatial direction).
+
+2. Add the offset to the sampling location and compute the value at the new
+   (fractional) location by bilinear interpolation over the four nearest
+   pixels.
+
+3. Multiply the sampled values by the kernel weights and accumulate them to
+   get the result.
+
+Compute 2-D deformable convolution on 4-D input.
+
+Given input image x and output feature map y, the deformable convolution
+operation can be expressed as follows:
+
+$$
+y(p) = \\sum_{k=1}^{K}{w_k * x(p + p_k + \\Delta p_k)}
+$$
+
+where $$\\Delta p_k$$ is the learnable offset for the k-th sampling location.
+
+Refer to https://arxiv.org/abs/1703.06211
+
+Example:
+  Input:
+       Input shape:  $(N, C_{in}, H_{in}, W_{in})$
+       Filter shape: $(C_{out}, C_{in}, H_f, W_f)$
+       Offset shape: $(N, 2 * deformable_groups * H_f * W_f, H_{out}, W_{out})$
+  Output:
+       Output shape: $(N, C_{out}, H_{out}, W_{out})$
+                     where $H_{out}$ and $W_{out}$ must equal $H_{in}$ and
+                     $W_{in}$, respectively.
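+  (Editor's note, not in the original doc: a worked instance of the formula
+  below. With $H_{in} = W_{in} = 64$, $H_f = W_f = 3$, paddings = [1, 1],
+  strides = [1, 1], dilations = [1, 1] and deformable_groups = 1, the
+  formula gives $H_{out} = (64 + 2 - 3) / 1 + 1 = 64$, so Offset must have
+  shape $(N, 2 * 1 * 3 * 3, 64, 64) = (N, 18, 64, 64)$.)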
+  Where
+$$
+       H_{out}= \frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\
+       W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
+$$
+)DOC");
+  }
+};
+
+class DeformableConvV1Op : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
+                      "Input(Input) of DeformableConvV1Op "
+                      "should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Offset"), true,
+                      "Input(Offset) of DeformableConvV1Op "
+                      "should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Filter"), true,
+                      "Input(Filter) of DeformableConvV1Op "
+                      "should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Output"), true,
+                      "Output(Output) of DeformableConvV1Op "
+                      "should not be null.");
+
+    auto in_dims = ctx->GetInputDim("Input");
+    auto filter_dims = ctx->GetInputDim("Filter");
+    auto offset_dims = ctx->GetInputDim("Offset");
+
+    std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
+    std::vector<int> dilations =
+        ctx->Attrs().Get<std::vector<int>>("dilations");
+    int groups = ctx->Attrs().Get<int>("groups");
+    int deformable_groups = ctx->Attrs().Get<int>("deformable_groups");
+    int im2col_step = ctx->Attrs().Get<int>("im2col_step");
+
+    PADDLE_ENFORCE_EQ(in_dims.size(), 4,
+                      "Conv input should be a 4-D tensor, got %u",
+                      in_dims.size());
+    PADDLE_ENFORCE_EQ(
+        in_dims.size(), filter_dims.size(),
+        "Conv input dimension and filter dimension should be the same.");
+    PADDLE_ENFORCE_EQ(
+        in_dims.size() - strides.size(), 2U,
+        "Conv input dimension and strides dimension should be consistent.");
+    PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
+                      "Conv paddings dimension and Conv strides dimension "
+                      "should be the same.");
+
+    PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
+                      "The number of input channels should equal filter "
+                      "channels * groups.");
+    PADDLE_ENFORCE_EQ(
+        filter_dims[0] % groups, 0,
+        "The number of output channels should be divisible by groups.");
+    PADDLE_ENFORCE_EQ(filter_dims[0] % deformable_groups, 0,
+                      "The number of output channels should be "
+                      "divisible by deformable_groups.");
+
+    if (in_dims[0] > im2col_step) {
+      PADDLE_ENFORCE_EQ(in_dims[0] % im2col_step, 0U,
+                        "When the input batch size exceeds im2col_step, it "
+                        "must be divisible by im2col_step.");
+    }
+
+    for (size_t i = 0; i < strides.size(); ++i) {
+      PADDLE_ENFORCE_GT(strides[i], 0U, "stride %d size incorrect", i);
+    }
+    for (size_t i = 0; i < dilations.size(); ++i) {
+      PADDLE_ENFORCE_GT(dilations[i], 0U, "dilation %d size incorrect", i);
+    }
+
+    std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
+    for (size_t i = 0; i < strides.size(); ++i) {
+      output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2],
+                                            dilations[i], paddings[i],
+                                            strides[i]));
+    }
+    PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U,
+                      "The number of output channels must be divisible by "
+                      "deformable_groups.");
+    PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
+                      "Output height must equal offset map height.");
+    PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
+                      "Output width must equal offset map width.");
+    PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
+                      "The offset channel number must be divisible by the "
+                      "filter window size kernel_h * kernel_w.");
+    PADDLE_ENFORCE_EQ(offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
+                      deformable_groups,
+                      "The offset channel number must equal 2 * "
+                      "deformable_groups * kernel_h * kernel_w.");
+
+    ctx->SetOutputDim("Output",
+                      framework::make_ddim(output_shape));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                   ctx.device_context());
+  }
+};
+
+class DeformableConvV1GradOpDescMaker
+    : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+
+    op->SetType("deformable_conv_v1_grad");
+    op->SetInput("Input", Input("Input"));
+    op->SetInput("Filter", Input("Filter"));
+    op->SetInput("Offset", Input("Offset"));
+    op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
+
+    op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
+    op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
+    op->SetOutput(framework::GradVarName("Offset"), InputGrad("Offset"));
+
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+class DeformableConvV1GradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    auto in_dims = ctx->GetInputDim("Input");
+    auto filter_dims = ctx->GetInputDim("Filter");
+    auto offset_dims = ctx->GetInputDim("Offset");
+
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Output")), true,
+                      "The gradient of Output(Output) must not be null.");
+    if (ctx->HasOutput(framework::GradVarName("Input"))) {
+      ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
+    }
+    if (ctx->HasOutput(framework::GradVarName("Filter"))) {
+      ctx->SetOutputDim(framework::GradVarName("Filter"), filter_dims);
+    }
+    if (ctx->HasOutput(framework::GradVarName("Offset"))) {
+      ctx->SetOutputDim(framework::GradVarName("Offset"), offset_dims);
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                   ctx.device_context());
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(deformable_conv_v1, ops::DeformableConvV1Op,
+                  ops::DeformableConvV1OpMaker,
+                  ops::DeformableConvV1GradOpDescMaker);
+REGISTER_OPERATOR(deformable_conv_v1_grad, ops::DeformableConvV1GradOp);
+
+// Editor's note: the template argument below was lost in extraction and is
+// restored as <float> on the assumption of a single float CPU kernel.
+REGISTER_OP_CPU_KERNEL(deformable_conv_v1,
+                       ops::DeformableConvV1CPUKernel<float>);
+REGISTER_OP_CPU_KERNEL(deformable_conv_v1_grad,
+                       ops::DeformableConvV1GradCPUKernel<float>);

diff --git a/paddle/fluid/operators/deformable_conv_v1_op.cu b/paddle/fluid/operators/deformable_conv_v1_op.cu
new file mode 100644
index 0000000000..a865766f9a
--- /dev/null
+++ b/paddle/fluid/operators/deformable_conv_v1_op.cu
@@ -0,0 +1,609 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// +// Part of the following code in this file refs to +// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu +// +// Copyright (c) 2017 Microsoft +// Licensed under The Apache-2.0 License [see LICENSE for details] +// \file deformable_psroi_pooling.cu +// \brief +// \author Yi Li, Guodong Zhang, Jifeng Dai + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/deformable_conv_filter.cu.h" +#include "paddle/fluid/operators/deformable_conv_func.h" +#include "paddle/fluid/operators/deformable_conv_v1_op.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using CUDADeviceContext = paddle::platform::CUDADeviceContext; + +static constexpr int kNumCUDAThread = 512; +static constexpr int kNumMaximumNumBlock = 4096; + +static inline int NumBlock(const int N) { + return std::min((N + kNumCUDAThread - 1) / kNumCUDAThread, + kNumMaximumNumBlock); +} + +template +__global__ void DeformableCol2imCUDAKernel( + const int nthreads, const T* data_col, const T* data_offset, + const int channels, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int deformable_group, const int height_col, const int width_col, + T* grad_im) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t thread = index; thread < nthreads; thread += offset) { + const int j = (thread / width_col / height_col / batch_size) % kernel_w; + const int i = + (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + thread / width_col / height_col / batch_size / kernel_w / kernel_h; + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = thread % width_col; + int h_out = (thread / width_col) % height_col; + int b = (thread / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[thread]; + const int cur_h = static_cast(cur_inv_h_data); + const int cur_w = static_cast(cur_inv_w_data); + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = + DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, + cur_w + dx, height, 
width); + + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +inline void DeformableCol2im(const platform::CUDADeviceContext& ctx, + const T* data_col, const T* data_offset, + const std::vector im_shape, + const std::vector col_shape, + const std::vector kernel_shape, + const std::vector pad, + const std::vector stride, + const std::vector dilation, + const int deformable_group, T* grad_im) { + int channel_per_deformable_group = im_shape[0] / deformable_group; + int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + int blocks = NumBlock(num_kernels); + int threads = kNumCUDAThread; + + DeformableCol2imCUDAKernel<<< + blocks, threads, 0, + reinterpret_cast(ctx).stream()>>>( + num_kernels, data_col, data_offset, im_shape[0], im_shape[1], im_shape[2], + kernel_shape[2], kernel_shape[3], pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], channel_per_deformable_group, col_shape[1], + deformable_group, col_shape[2], col_shape[3], grad_im); +} + +template +__global__ void DeformableCol2imCoordCUDAKernel( + const int nthreads, const T* data_col, const T* data_im, + const T* data_offset, const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, T* grad_offset) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + T val = 0, mval = 0; + const int w = i % width_col; + const int h = (i / width_col) % height_col; + const int c = (i / width_col / height_col) % offset_channels; + const int b = (i / width_col / height_col) / offset_channels; + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T* data_col_ptr = data_col + + deformable_group_index * + channel_per_deformable_group * batch_size * + width_col * height_col; + const T* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / + kernel_w * height * width; + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * 
dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } else { + mval += data_col_ptr[col_pos] * + DmcnIm2colBilinear(data_im_ptr + cnt * height * width, width, + height, width, inv_h, inv_w); + } + const T weight = DmcnGetCoordinateWeight( + inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + grad_offset[i] = val; + } +} + +template +inline void DeformableCol2imCoord( + const platform::CUDADeviceContext& ctx, const T* data_col, const T* data_im, + const T* data_offset, const std::vector im_shape, + const std::vector col_shape, + const std::vector kernel_shape, const std::vector paddings, + const std::vector strides, const std::vector dilations, + const int deformable_groups, T* grad_offset) { + int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int channel_per_deformable_group = col_shape[0] / deformable_groups; + int blocks = NumBlock(num_kernels); + int threads = kNumCUDAThread; + + DeformableCol2imCoordCUDAKernel<<< + blocks, threads, 0, + reinterpret_cast(ctx).stream()>>>( + num_kernels, data_col, data_im, data_offset, im_shape[0], im_shape[1], + im_shape[2], kernel_shape[2], kernel_shape[3], paddings[0], paddings[1], + strides[0], strides[1], dilations[0], dilations[1], + channel_per_deformable_group, col_shape[1], + 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, + deformable_groups, col_shape[2], col_shape[3], grad_offset); +} + +template +__global__ void DeformableIm2colCUDAKernel( + const int nthreads, const T* data_im, const T* data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T* data_col) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (size_t i = index; i < nthreads; i += offset) { + const int w_col = i % width_col; + const int h_col = (i / width_col) % height_col; + const int b_col = (i / width_col) / height_col % batch_size; + const int c_im = (i / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T* data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T* data_offset_ptr = + data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + 
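+          // Editor's note (added comment): (h_im, w_im) is the fractional
+          // sampling position after adding the learned offset; the guard
+          // below leaves val at zero when the position falls outside the
+          // feature map, which matches zero padding.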
if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + val = + DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +inline void DeformableIm2col(const platform::CUDADeviceContext& ctx, + const T* data_im, const T* data_offset, + const std::vector im_shape, + const std::vector col_shape, + const std::vector filter_shape, + const std::vector paddings, + const std::vector strides, + const std::vector dilations, + const int deformable_groups, T* data_col) { + int channel_per_deformable_group = im_shape[0] / deformable_groups; + int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + + int blocks = NumBlock(num_kernels); + int threads = kNumCUDAThread; + + // get outputs of im2col with offset by bilinear interpolation + DeformableIm2colCUDAKernel<<< + blocks, threads, 0, + reinterpret_cast(ctx).stream()>>>( + num_kernels, data_im, data_offset, im_shape[1], im_shape[2], + filter_shape[2], filter_shape[3], paddings[0], paddings[1], strides[0], + strides[1], dilations[0], dilations[1], channel_per_deformable_group, + col_shape[1], im_shape[0], deformable_groups, col_shape[2], col_shape[3], + data_col); +} + +template +class DeformableConvV1CUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* input = ctx.Input("Input"); + const Tensor offset = *ctx.Input("Offset"); + Tensor filter = *ctx.Input("Filter"); + Tensor* output = ctx.Output("Output"); + output->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + const int groups = ctx.Attr("groups"); + const int deformable_groups = ctx.Attr("deformable_groups"); + const int im2col_step = ctx.Attr("im2col_step"); + const std::vector strides = ctx.Attr>("strides"); + const std::vector paddings = ctx.Attr>("paddings"); + const std::vector dilations = ctx.Attr>("dilations"); + + const int batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec(framework::vectorize(output->dims())); + + // col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w} + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = + input->dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_buffer_shape_vec)); + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + framework::DDim output_shape(framework::make_ddim(output_buffer_shape_vec)); + Tensor col_buffer; + Tensor output_buffer; + col_buffer = + ctx.AllocateTmpTensor(col_shape, dev_ctx); + output_buffer = + ctx.AllocateTmpTensor(output_shape, dev_ctx); + + int64_t M = output_shape_vec[1] / groups; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = + input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups; + + Tensor weight_3d; + weight_3d.ShareDataWith(filter).Resize( + framework::make_ddim({groups, M, K})); + Tensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer) + .Resize(framework::make_ddim({groups, K, N})); + Tensor output_4d; + output_4d.ShareDataWith(output_buffer) + 
.Resize(framework::make_ddim({batch_size / im2col_step, groups, M, N})); + output_4d.mutable_data(ctx.GetPlace()); + framework::DDim input_shape = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + std::vector input_shape_vec = framework::vectorize(input_shape); + + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset.numel() / offset.dims()[0]; + + auto blas = math::GetBlas(dev_ctx); + + const T* input_ptr = input->data(); + const T* offset_ptr = offset.data(); + col_buffer.mutable_data(ctx.GetPlace()); + T* col_buffer_ptr = col_buffer.data(); + + for (int i = 0; i < batch_size / im2col_step; ++i) { + DeformableIm2col(dev_ctx, input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + input_shape_vec, col_buffer_shape_vec, filter_shape_vec, + paddings, strides, dilations, deformable_groups, + col_buffer_ptr); + + Tensor output_3d = output_4d.Slice(i, i + 1).Resize( + framework::slice_ddim(output_4d.dims(), 1, output_4d.dims().size())); + // get the product of pixel and weight + for (int g = 0; g < groups; ++g) { + Tensor weight_3d_slice = + weight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + weight_3d.dims(), 1, weight_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + Tensor output_3d_slice = + output_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + output_3d.dims(), 1, output_3d.dims().size())); + + blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0), + &output_3d_slice, T(0.0)); + } + } + output->ShareDataWith(output_buffer) + .Resize(framework::make_ddim(output_shape_vec)); + } +}; + +template +class DeformableConvV1GradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* output_grad = + ctx.Input(framework::GradVarName("Output")); + Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); + Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); + Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); + + const Tensor* input = ctx.Input("Input"); + Tensor offset = *ctx.Input("Offset"); + Tensor filter = *ctx.Input("Filter"); + if (!input_grad && !filter_grad && !offset_grad) return; + + int groups = ctx.Attr("groups"); + int deformable_groups = ctx.Attr("deformable_groups"); + int im2col_step = ctx.Attr("im2col_step"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + + auto& dev_ctx = ctx.template device_context(); + const int batch_size = static_cast(input->dims()[0]); + + framework::DDim input_shape = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + std::vector input_shape_vec = framework::vectorize(input_shape); + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec( + framework::vectorize(output_grad->dims())); + + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = + input->dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_buffer_shape_vec)); + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * 
output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + framework::DDim output_shape(framework::make_ddim(output_buffer_shape_vec)); + Tensor col_buffer; + Tensor output_buffer; + col_buffer = + ctx.AllocateTmpTensor(col_shape, dev_ctx); + output_buffer = + ctx.AllocateTmpTensor(output_shape, dev_ctx); + + output_buffer.ShareDataWith(*output_grad); + + int64_t M = + input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = output_shape_vec[1] / groups; + + framework::DDim weight_3d_shape = {groups, K, M}; + framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, + N}; + framework::DDim col_buffer_3d_shape = {groups, M, N}; + framework::DDim filter_grad_shape = {groups, K, M}; + + Tensor weight_3d; + weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); + Tensor out_grad_4d; + out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); + Tensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); + + math::SetConstant set_zero; + auto blas = math::GetBlas(dev_ctx); + + col_buffer.mutable_data(ctx.GetPlace()); + col_buffer_3d.mutable_data(ctx.GetPlace()); + out_grad_4d.mutable_data(ctx.GetPlace()); + + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset.numel() / offset.dims()[0]; + + if (filter_grad) { + filter_grad->mutable_data(ctx.GetPlace()); + filter_grad->Resize(filter_grad_shape); + set_zero(dev_ctx, filter_grad, static_cast(0)); + } + + if (input_grad) { + input_grad->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, input_grad, static_cast(0)); + } + + if (offset_grad) { + offset_grad->mutable_data(ctx.GetPlace()); + set_zero(dev_ctx, offset_grad, static_cast(0)); + } + + for (int i = 0; i < batch_size / im2col_step; ++i) { + Tensor out_grad_3d = + out_grad_4d.Slice(i, i + 1).Resize(framework::slice_ddim( + out_grad_4d.dims(), 1, out_grad_4d.dims().size())); + for (int g = 0; g < groups; ++g) { + Tensor weight_3d_slice = + weight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + weight_3d.dims(), 1, weight_3d.dims().size())); + Tensor out_grad_3d_slice = + out_grad_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + out_grad_3d.dims(), 1, out_grad_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + + blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0), + &col_buffer_3d_slice, T(0.0)); + } + col_buffer.Resize(col_shape); + + T* col_buffer_ptr = col_buffer.data(); + const T* input_ptr = input->data(); + const T* offset_ptr = offset.data(); + + if (offset_grad) { + T* offset_grad_ptr = offset_grad->data(); + // get grad of offset + DeformableCol2imCoord( + dev_ctx, col_buffer_ptr, input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, input_shape_vec, + col_buffer_shape_vec, filter_shape_vec, paddings, strides, + dilations, deformable_groups, + offset_grad_ptr + i * im2col_step * input_offset_dim); + } + if (input_grad) { + T* input_grad_ptr = input_grad->data(); + // get grad of input + DeformableCol2im(dev_ctx, col_buffer_ptr, + offset_ptr + i * im2col_step * input_offset_dim, + input_shape_vec, col_buffer_shape_vec, + filter_shape_vec, paddings, strides, dilations, + deformable_groups, + input_grad_ptr + i * im2col_step * input_dim); + input_grad->Resize(input->dims()); + } + + DeformableIm2col(dev_ctx, 
input_ptr + i * im2col_step * input_dim,
+                       offset_ptr + i * im2col_step * input_offset_dim,
+                       input_shape_vec, col_buffer_shape_vec, filter_shape_vec,
+                       paddings, strides, dilations, deformable_groups,
+                       col_buffer_ptr);
+
+      col_buffer_3d.Resize(col_buffer_3d_shape);
+
+      if (filter_grad) {
+        Tensor dweight_3d;
+        dweight_3d = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
+            filter_grad_shape, dev_ctx);
+        for (int g = 0; g < groups; ++g) {
+          Tensor out_grad_3d_slice =
+              out_grad_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                  out_grad_3d.dims(), 1, out_grad_3d.dims().size()));
+          Tensor col_buffer_3d_slice =
+              col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                  col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
+          Tensor dweight_3d_slice =
+              dweight_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                  dweight_3d.dims(), 1, dweight_3d.dims().size()));
+
+          blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true,
+                      T(1.0), &dweight_3d_slice, T(0.0));
+        }
+        // accumulate the per-step weight gradient into filter_grad
+        FilterGradAddupCUDAKernel<T><<<NumBlock(dweight_3d.numel()),
+                                       kNumCUDAThread, 0, dev_ctx.stream()>>>(
+            dweight_3d.numel(), groups, K, M, dweight_3d.data<T>(),
+            filter_grad->data<T>());
+      }
+    }
+    if (filter_grad) {
+      filter_grad->Resize(filter.dims());
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(deformable_conv_v1,
+                        ops::DeformableConvV1CUDAKernel<float>,
+                        ops::DeformableConvV1CUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(deformable_conv_v1_grad,
+                        ops::DeformableConvV1GradCUDAKernel<float>,
+                        ops::DeformableConvV1GradCUDAKernel<double>);
diff --git a/paddle/fluid/operators/deformable_conv_v1_op.h b/paddle/fluid/operators/deformable_conv_v1_op.h
new file mode 100644
index 0000000000..89dc10cfa3
--- /dev/null
+++ b/paddle/fluid/operators/deformable_conv_v1_op.h
@@ -0,0 +1,564 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
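+//
+// A sketch of the indexing the kernels below implement (assuming NCHW layout
+// and one deformable group): for output pixel (h, w) and filter tap (i, j),
+//   delta_h = offset[b, 2 * (i * kernel_w + j) + 0, h, w]
+//   delta_w = offset[b, 2 * (i * kernel_w + j) + 1, h, w]
+// and the fractional sampling location on the input plane is
+//   h_im = h * stride_h - pad_h + i * dilation_h + delta_h
+//   w_im = w * stride_w - pad_w + j * dilation_w + delta_w
+// Values at (h_im, w_im) are read with bilinear interpolation; samples that
+// fall outside the image contribute zero.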
+// +// Part of the following code in this file refs to +// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu +// +// Copyright (c) 2017 Microsoft +// Licensed under The Apache-2.0 License [see LICENSE for details] +// \file deformable_psroi_pooling.cu +// \brief +// \author Yi Li, Guodong Zhang, Jifeng Dai + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/deformable_conv_func.h" +#include "paddle/fluid/operators/deformable_conv_op.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using CPUDeviceContext = platform::CPUDeviceContext; + +template +void DeformableCol2imCPUKernel( + const int num_kernels, const T* data_col, const T* data_offset, + const int channels, const int height, const int width, const int kernel_h, + const int kernel_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int deformable_group, const int height_col, const int width_col, + T* grad_im) { + for (size_t thread = 0; thread < num_kernels; thread++) { + const int j = (thread / width_col / height_col / batch_size) % kernel_w; + const int i = + (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + thread / width_col / height_col / batch_size / kernel_w / kernel_h; + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = thread % width_col; + int h_out = (thread / width_col) % height_col; + int b = (thread / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + const T cur_inv_h_data = h_in + i * dilation_h + offset_h; + const T cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const T cur_top_grad = data_col[thread]; + const int cur_h = static_cast(cur_inv_h_data); + const int cur_w = static_cast(cur_inv_w_data); + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + T weight = + DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, + cur_w + dx, height, width); + + *(grad_im + cur_bottom_grad_pos) = + *(grad_im + cur_bottom_grad_pos) + weight * cur_top_grad; + } + } + } + } +} + +template +inline void DeformableCol2imCPU(const platform::CPUDeviceContext& ctx, + const T* data_col, const T* data_offset, + const std::vector im_shape, + const std::vector col_shape, + const std::vector kernel_shape, + const std::vector pad, + const std::vector stride, + const std::vector dilation, + const int deformable_group, T* 
grad_im) { + int channel_per_deformable_group = im_shape[0] / deformable_group; + int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + + DeformableCol2imCPUKernel( + num_kernels, data_col, data_offset, im_shape[0], im_shape[1], im_shape[2], + kernel_shape[2], kernel_shape[3], pad[0], pad[1], stride[0], stride[1], + dilation[0], dilation[1], channel_per_deformable_group, col_shape[1], + deformable_group, col_shape[2], col_shape[3], grad_im); +} + +template +void DeformableCol2imCoordCPUKernel( + const int num_kernels, const T* data_col, const T* data_im, + const T* data_offset, const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, T* grad_offset) { + for (size_t i = 0; i < num_kernels; i++) { + T val = 0, mval = 0; + const int w = i % width_col; + const int h = (i / width_col) % height_col; + const int c = (i / width_col / height_col) % offset_channels; + const int b = (i / width_col / height_col) / offset_channels; + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const T* data_col_ptr = data_col + + deformable_group_index * + channel_per_deformable_group * batch_size * + width_col * height_col; + const T* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / + kernel_w * height * width; + const T* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * + width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T inv_h = h_in + i * dilation_h + offset_h; + T inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } else { + mval += data_col_ptr[col_pos] * + DmcnIm2colBilinear(data_im_ptr + cnt * height * width, width, + height, width, inv_h, inv_w); + } + const T weight = DmcnGetCoordinateWeight( + inv_h, inv_w, height, width, data_im_ptr + cnt * height * width, + width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + grad_offset[i] = val; + } +} + +template +inline void DeformableCol2imCoordCPU( + const platform::CPUDeviceContext& ctx, const T* data_col, const T* data_im, + const T* data_offset, const std::vector im_shape, + const std::vector 
col_shape, + const std::vector kernel_shape, const std::vector paddings, + const std::vector strides, const std::vector dilations, + const int deformable_groups, T* grad_offset) { + int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int channel_per_deformable_group = col_shape[0] / deformable_groups; + + DeformableCol2imCoordCPUKernel( + num_kernels, data_col, data_im, data_offset, im_shape[0], im_shape[1], + im_shape[2], kernel_shape[2], kernel_shape[3], paddings[0], paddings[1], + strides[0], strides[1], dilations[0], dilations[1], + channel_per_deformable_group, col_shape[1], + 2 * kernel_shape[2] * kernel_shape[3] * deformable_groups, + deformable_groups, col_shape[2], col_shape[3], grad_offset); +} + +template +void DeformableIm2colCPUKernel( + const int num_kernels, const T* data_im, const T* data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, const int batch_size, + const int num_channels, const int deformable_group, const int height_col, + const int width_col, T* data_col) { + for (size_t i = 0; i < num_kernels; i++) { + const int w_col = i % width_col; + const int h_col = (i / width_col) % height_col; + const int b_col = (i / width_col) / height_col % batch_size; + const int c_im = (i / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T* data_col_ptr = + data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const T* data_offset_ptr = + data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + + const T offset_h = data_offset_ptr[data_offset_h_ptr]; + const T offset_w = data_offset_ptr[data_offset_w_ptr]; + T val = static_cast(0); + const T h_im = h_in + i * dilation_h + offset_h; + const T w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + val = + DmcnIm2colBilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +template +inline void DeformableIm2colCPU(const platform::CPUDeviceContext& ctx, + const T* data_im, const T* data_offset, + const std::vector im_shape, + const std::vector col_shape, + const std::vector filter_shape, + const std::vector paddings, + const std::vector strides, + const std::vector dilations, + const int deformable_groups, T* data_col) { + int channel_per_deformable_group = im_shape[0] / deformable_groups; + int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + + // get outputs of im2col with offset by bilinear interpolation + DeformableIm2colCPUKernel( + num_kernels, data_im, data_offset, im_shape[1], im_shape[2], + filter_shape[2], 
filter_shape[3], paddings[0], paddings[1], strides[0], + strides[1], dilations[0], dilations[1], channel_per_deformable_group, + col_shape[1], im_shape[0], deformable_groups, col_shape[2], col_shape[3], + data_col); +} + +template +class DeformableConvV1CPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* offset = ctx.Input("Offset"); + Tensor filter = *ctx.Input("Filter"); + Tensor* output = ctx.Output("Output"); + output->mutable_data(ctx.GetPlace()); + + auto& dev_ctx = ctx.template device_context(); + + const int groups = ctx.Attr("groups"); + const int deformable_groups = ctx.Attr("deformable_groups"); + const int im2col_step = ctx.Attr("im2col_step"); + const std::vector strides = ctx.Attr>("strides"); + const std::vector paddings = ctx.Attr>("paddings"); + const std::vector dilations = ctx.Attr>("dilations"); + + const int batch_size = static_cast(input->dims()[0]); + + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec(framework::vectorize(output->dims())); + + // col_shape_vec: {c_i * k_h * k_w, im2col_step, o_h, o_w} + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = + input->dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_buffer_shape_vec)); + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + framework::DDim output_shape(framework::make_ddim(output_buffer_shape_vec)); + Tensor col_buffer; + Tensor output_buffer; + col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); + output_buffer = + ctx.AllocateTmpTensor(output_shape, dev_ctx); + int64_t M = output_shape_vec[1] / groups; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = + input->dims()[1] * filter_shape_vec[2] * filter_shape_vec[3] / groups; + + Tensor weight_3d; + weight_3d.ShareDataWith(filter).Resize( + framework::make_ddim({groups, M, K})); + Tensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer) + .Resize(framework::make_ddim({groups, K, N})); + Tensor output_4d; + output_4d.ShareDataWith(output_buffer) + .Resize(framework::make_ddim({batch_size / im2col_step, groups, M, N})); + output_4d.mutable_data(ctx.GetPlace()); + framework::DDim input_shape = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + std::vector input_shape_vec = framework::vectorize(input_shape); + int input_dim = input->numel() / input->dims()[0]; + int input_offset_dim = offset->numel() / offset->dims()[0]; + auto blas = math::GetBlas(dev_ctx); + const T* input_ptr = input->data(); + const T* offset_ptr = offset->data(); + col_buffer.mutable_data(ctx.GetPlace()); + T* col_buffer_ptr = col_buffer.data(); + for (int i = 0; i < batch_size / im2col_step; ++i) { + DeformableIm2colCPU(dev_ctx, input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + input_shape_vec, col_buffer_shape_vec, + filter_shape_vec, paddings, strides, dilations, + deformable_groups, col_buffer_ptr); + Tensor output_3d = output_4d.Slice(i, i + 1).Resize( + framework::slice_ddim(output_4d.dims(), 1, output_4d.dims().size())); + // get the product of pixel and weight + for (int g 
= 0; g < groups; ++g) { + Tensor weight_3d_slice = + weight_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + weight_3d.dims(), 1, weight_3d.dims().size())); + Tensor col_buffer_3d_slice = + col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + col_buffer_3d.dims(), 1, col_buffer_3d.dims().size())); + Tensor output_3d_slice = + output_3d.Slice(g, g + 1).Resize(framework::slice_ddim( + output_3d.dims(), 1, output_3d.dims().size())); + blas.MatMul(weight_3d_slice, false, col_buffer_3d_slice, false, T(1.0), + &output_3d_slice, T(0.0)); + } + } + output->ShareDataWith(output_buffer) + .Resize(framework::make_ddim(output_shape_vec)); + } +}; + +template +class DeformableConvV1GradCPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const Tensor* output_grad = + ctx.Input(framework::GradVarName("Output")); + Tensor* input_grad = ctx.Output(framework::GradVarName("Input")); + Tensor* filter_grad = ctx.Output(framework::GradVarName("Filter")); + Tensor* offset_grad = ctx.Output(framework::GradVarName("Offset")); + + const Tensor* input = ctx.Input("Input"); + Tensor offset = *ctx.Input("Offset"); + Tensor filter = *ctx.Input("Filter"); + if (!input_grad && !filter_grad && !offset_grad) return; + + int groups = ctx.Attr("groups"); + int deformable_groups = ctx.Attr("deformable_groups"); + int im2col_step = ctx.Attr("im2col_step"); + std::vector strides = ctx.Attr>("strides"); + std::vector paddings = ctx.Attr>("paddings"); + std::vector dilations = ctx.Attr>("dilations"); + + auto& dev_ctx = ctx.template device_context(); + const int batch_size = static_cast(input->dims()[0]); + + framework::DDim input_shape = + framework::slice_ddim(input->dims(), 1, input->dims().size()); + std::vector input_shape_vec = framework::vectorize(input_shape); + std::vector filter_shape_vec(framework::vectorize(filter.dims())); + std::vector output_shape_vec( + framework::vectorize(output_grad->dims())); + + std::vector col_buffer_shape_vec(filter_shape_vec.size()); + col_buffer_shape_vec[0] = + input->dims()[1] * filter.dims()[2] * filter.dims()[3]; + col_buffer_shape_vec[1] = im2col_step; + for (size_t j = 0; j < filter_shape_vec.size() - 2; ++j) { + col_buffer_shape_vec[j + 2] = output_shape_vec[j + 2]; + } + framework::DDim col_shape(framework::make_ddim(col_buffer_shape_vec)); + std::vector output_buffer_shape_vec(1); + output_buffer_shape_vec[0] = batch_size * output_shape_vec[1] * + output_shape_vec[2] * output_shape_vec[3]; + framework::DDim output_shape(framework::make_ddim(output_buffer_shape_vec)); + Tensor col_buffer; + Tensor output_buffer; + col_buffer = ctx.AllocateTmpTensor(col_shape, dev_ctx); + output_buffer = + ctx.AllocateTmpTensor(output_shape, dev_ctx); + + output_buffer.ShareDataWith(*output_grad); + + int64_t M = + input_shape_vec[0] / groups * filter_shape_vec[2] * filter_shape_vec[3]; + int64_t N = im2col_step * output_shape_vec[2] * output_shape_vec[3]; + int64_t K = output_shape_vec[1] / groups; + + framework::DDim weight_3d_shape = {groups, K, M}; + framework::DDim out_grad_4d_shape = {batch_size / im2col_step, groups, K, + N}; + framework::DDim col_buffer_3d_shape = {groups, M, N}; + framework::DDim filter_grad_shape = {groups, K, M}; + + Tensor weight_3d; + weight_3d.ShareDataWith(filter).Resize(weight_3d_shape); + Tensor out_grad_4d; + out_grad_4d.ShareDataWith(output_buffer).Resize(out_grad_4d_shape); + Tensor col_buffer_3d; + col_buffer_3d.ShareDataWith(col_buffer).Resize(col_buffer_3d_shape); + + 
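+    // Shapes used by the per-group GEMMs below, with M, N, K as computed
+    // above:
+    //   weight_3d:     {groups, K, M},  K = o_c / groups,
+    //                                   M = (i_c / groups) * f_h * f_w
+    //   out_grad_4d:   {batch_size / im2col_step, groups, K, N},
+    //                                   N = im2col_step * o_h * o_w
+    //   col_buffer_3d: {groups, M, N}
+    // dcol = weight^T * dout recovers the im2col-space gradient, which
+    // DeformableCol2imCoordCPU / DeformableCol2imCPU then scatter back to
+    // the offset and input gradients.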
math::SetConstant<CPUDeviceContext, T> set_zero;
+    auto blas = math::GetBlas<CPUDeviceContext, T>(dev_ctx);
+
+    col_buffer.mutable_data<T>(ctx.GetPlace());
+    col_buffer_3d.mutable_data<T>(ctx.GetPlace());
+    out_grad_4d.mutable_data<T>(ctx.GetPlace());
+
+    int input_dim = input->numel() / input->dims()[0];
+    int input_offset_dim = offset.numel() / offset.dims()[0];
+
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(ctx.GetPlace());
+      filter_grad->Resize(filter_grad_shape);
+      set_zero(dev_ctx, filter_grad, static_cast<T>(0));
+    }
+
+    if (input_grad) {
+      input_grad->mutable_data<T>(ctx.GetPlace());
+      set_zero(dev_ctx, input_grad, static_cast<T>(0));
+    }
+
+    if (offset_grad) {
+      offset_grad->mutable_data<T>(ctx.GetPlace());
+      set_zero(dev_ctx, offset_grad, static_cast<T>(0));
+    }
+
+    for (int i = 0; i < batch_size / im2col_step; ++i) {
+      Tensor out_grad_3d =
+          out_grad_4d.Slice(i, i + 1).Resize(framework::slice_ddim(
+              out_grad_4d.dims(), 1, out_grad_4d.dims().size()));
+      for (int g = 0; g < groups; ++g) {
+        Tensor weight_3d_slice =
+            weight_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                weight_3d.dims(), 1, weight_3d.dims().size()));
+        Tensor out_grad_3d_slice =
+            out_grad_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                out_grad_3d.dims(), 1, out_grad_3d.dims().size()));
+        Tensor col_buffer_3d_slice =
+            col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
+
+        blas.MatMul(weight_3d_slice, true, out_grad_3d_slice, false, T(1.0),
+                    &col_buffer_3d_slice, T(0.0));
+      }
+      col_buffer.Resize(col_shape);
+
+      T* col_buffer_ptr = col_buffer.data<T>();
+      const T* input_ptr = input->data<T>();
+      const T* offset_ptr = offset.data<T>();
+
+      if (offset_grad) {
+        T* offset_grad_ptr = offset_grad->data<T>();
+        // get grad of offset
+        DeformableCol2imCoordCPU(
+            dev_ctx, col_buffer_ptr, input_ptr + i * im2col_step * input_dim,
+            offset_ptr + i * im2col_step * input_offset_dim, input_shape_vec,
+            col_buffer_shape_vec, filter_shape_vec, paddings, strides,
+            dilations, deformable_groups,
+            offset_grad_ptr + i * im2col_step * input_offset_dim);
+      }
+      if (input_grad) {
+        T* input_grad_ptr = input_grad->data<T>();
+        // get grad of input
+        DeformableCol2imCPU(dev_ctx, col_buffer_ptr,
+                            offset_ptr + i * im2col_step * input_offset_dim,
+                            input_shape_vec, col_buffer_shape_vec,
+                            filter_shape_vec, paddings, strides, dilations,
+                            deformable_groups,
+                            input_grad_ptr + i * im2col_step * input_dim);
+        input_grad->Resize(input->dims());
+      }
+
+      DeformableIm2colCPU(dev_ctx, input_ptr + i * im2col_step * input_dim,
+                          offset_ptr + i * im2col_step * input_offset_dim,
+                          input_shape_vec, col_buffer_shape_vec,
+                          filter_shape_vec, paddings, strides, dilations,
+                          deformable_groups, col_buffer_ptr);
+
+      col_buffer_3d.Resize(col_buffer_3d_shape);
+
+      if (filter_grad) {
+        Tensor dweight_3d;
+        dweight_3d = ctx.AllocateTmpTensor<T, CPUDeviceContext>(
+            filter_grad_shape, dev_ctx);
+        for (int g = 0; g < groups; ++g) {
+          Tensor out_grad_3d_slice =
+              out_grad_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                  out_grad_3d.dims(), 1, out_grad_3d.dims().size()));
+          Tensor col_buffer_3d_slice =
+              col_buffer_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                  col_buffer_3d.dims(), 1, col_buffer_3d.dims().size()));
+          Tensor dweight_3d_slice =
+              dweight_3d.Slice(g, g + 1).Resize(framework::slice_ddim(
+                  dweight_3d.dims(), 1, dweight_3d.dims().size()));
+
+          blas.MatMul(out_grad_3d_slice, false, col_buffer_3d_slice, true,
+                      T(1.0), &dweight_3d_slice, T(0.0));
+        }
+        // update grad of weights
+        FilterGradAddupCPUKernel(dweight_3d.numel(), groups, K, M,
+                                 dweight_3d.data<T>(),
filter_grad->data<T>());
+      }
+    }
+    if (filter_grad) {
+      filter_grad->Resize(filter.dims());
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 000a4d4d30..a4bf137831 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -13196,20 +13196,30 @@ def deformable_conv(input,
                     im2col_step=None,
                     param_attr=None,
                     bias_attr=None,
+                    modulated=True,
                     name=None):
     """
     **Deformable Convolution Layer**

     Compute 2-D deformable convolution on 4-D input. Given input image x and
     output feature map y, the deformable convolution operation can be
     expressed as follows:
+
+    Deformable Convolution v2:

     .. math::

         y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k}
+
+    Deformable Convolution v1:

-    Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, respectively.
-    Refer to `Deformable ConvNets v2: More Deformable, Better Results
-    <https://arxiv.org/abs/1811.11168>`_ .
+    .. math::
+
+        y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k)}
+
+    Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and the
+    modulation scalar for the k-th location, respectively; in deformable convolution
+    v1, :math:`\Delta m_k` is fixed to 1. Please refer to `Deformable ConvNets v2:
+    More Deformable, Better Results <https://arxiv.org/abs/1811.11168>`_ and
+    `Deformable Convolutional Networks <https://arxiv.org/abs/1703.06211>`_.

     Example:
         - Input:
@@ -13235,7 +13245,7 @@ def deformable_conv(input,
     Args:
         input (Variable): The input image with [N, C, H, W] format.
-        offset (Variable): The input coord offset of deformable convolution layer.
+        offset (Variable): The input coordinate offset of deformable convolution layer.
         mask (Variable): The input mask of deformable convolution layer.
         num_filters(int): The number of filters. It is the same as the number
            of output image channels.
@@ -13274,6 +13284,8 @@ def deformable_conv(input,
            to the output units. If it is set to None or one attribute of ParamAttr, conv2d
            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
            is not set, the bias is initialized zero. Default: None.
+        modulated (bool): Whether to use deformable convolution v2 (modulated) or
+           v1 (unmodulated); v2 is used when True. Default: True.
         name (str|None): A name for this layer(optional). If set None, the layer
            will be named automatically. Default: None
     Returns:
@@ -13285,12 +13297,22 @@ def deformable_conv(input,
     Examples:
        ..
code-block:: python + #deformable conv v2: + import paddle.fluid as fluid data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32') offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32') mask = fluid.layers.data(name='mask', shape=[9, 32, 32], dtype='float32') out = fluid.layers.deformable_conv(input=data, offset=offset, mask=mask, - num_filters=2, filter_size=3, padding=1) + num_filters=2, filter_size=3, padding=1, modulated=True) + + #deformable conv v1: + + import paddle.fluid as fluid + data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32') + offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32') + out = fluid.layers.deformable_conv(input=data, offset=offset, mask=None, + num_filters=2, filter_size=3, padding=1, modulated=False) """ num_channels = input.shape[1] @@ -13303,8 +13325,6 @@ def deformable_conv(input, raise TypeError("Input of deformable_conv must be Variable") if not isinstance(offset, Variable): raise TypeError("Input Offset of deformable_conv must be Variable") - if not isinstance(mask, Variable): - raise TypeError("Input Mask of deformable_conv must be Variable") if groups is None: num_filter_channels = num_channels @@ -13334,23 +13354,42 @@ def deformable_conv(input, pre_bias = helper.create_variable_for_type_inference(dtype) - helper.append_op( - type='deformable_conv', - inputs={ - 'Input': input, - 'Filter': filter_param, - 'Offset': offset, - 'Mask': mask, - }, - outputs={"Output": pre_bias}, - attrs={ - 'strides': stride, - 'paddings': padding, - 'dilations': dilation, - 'groups': groups, - 'deformable_groups': deformable_groups, - 'im2col_step': im2col_step, - }) + if modulated: + helper.append_op( + type='deformable_conv', + inputs={ + 'Input': input, + 'Filter': filter_param, + 'Offset': offset, + 'Mask': mask, + }, + outputs={"Output": pre_bias}, + attrs={ + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'groups': groups, + 'deformable_groups': deformable_groups, + 'im2col_step': im2col_step, + }) + + else: + helper.append_op( + type='deformable_conv_v1', + inputs={ + 'Input': input, + 'Filter': filter_param, + 'Offset': offset, + }, + outputs={"Output": pre_bias}, + attrs={ + 'strides': stride, + 'paddings': padding, + 'dilations': dilation, + 'groups': groups, + 'deformable_groups': deformable_groups, + 'im2col_step': im2col_step, + }) output = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) return output diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py index aacb9ff447..db3b2a8f96 100644 --- a/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_op.py @@ -145,48 +145,35 @@ class TestModulatedDeformableConvOp(OpTest): } self.outputs = {'Output': output} - def has_cuda(self): - return core.is_compiled_with_cuda() - def test_check_output(self): - if self.has_cuda(): - place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=1e-5) + self.check_output(atol=1e-5) def test_check_grad(self): - if self.has_cuda(): - place = core.CUDAPlace(0) - self.check_grad_with_place( - place, {'Input', 'Offset', 'Mask', 'Filter'}, - 'Output', - max_relative_error=0.05) + self.check_grad( + {'Input', 'Offset', 'Mask', 'Filter'}, + 'Output', + max_relative_error=0.05) def test_check_grad_no_filter(self): - if self.has_cuda(): - place = core.CUDAPlace(0) - self.check_grad_with_place( - place, 
['Input', 'Offset', 'Mask'], - 'Output', - max_relative_error=0.1, - no_grad_set=set(['Filter'])) + self.check_grad( + ['Input', 'Offset', 'Mask'], + 'Output', + max_relative_error=0.1, + no_grad_set=set(['Filter'])) def test_check_grad_no_input(self): - if self.has_cuda(): - place = core.CUDAPlace(0) - self.check_grad_with_place( - place, ['Filter', 'Offset', 'Mask'], - 'Output', - max_relative_error=0.1, - no_grad_set=set(['Input'])) + self.check_grad( + ['Filter', 'Offset', 'Mask'], + 'Output', + max_relative_error=0.1, + no_grad_set=set(['Input'])) def test_check_grad_no_offset_no_mask(self): - if self.has_cuda(): - place = core.CUDAPlace(0) - self.check_grad_with_place( - place, ['Input', 'Filter'], - 'Output', - max_relative_error=0.1, - no_grad_set=set(['Offset', 'Mask'])) + self.check_grad( + ['Input', 'Filter'], + 'Output', + max_relative_error=0.1, + no_grad_set=set(['Offset', 'Mask'])) def init_test_case(self): self.pad = [1, 1] diff --git a/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py b/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py new file mode 100644 index 0000000000..5646f72f72 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_deformable_conv_v1_op.py @@ -0,0 +1,240 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
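+#
+# The reference implementation below computes deformable conv v1 directly in
+# NumPy: for every output pixel it samples the input at offset-shifted,
+# fractional locations with bilinear interpolation (dmc_bilinear), packs the
+# samples into an im2col buffer, and reduces that buffer with a per-group
+# GEMM (dconv_im2col_gemm). The operator's output is checked against this
+# result.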
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+import paddle.fluid.core as core
+from op_test import OpTest
+
+
+def dmc_bilinear(data_im, height, width, h, w):
+    h_low = int(np.floor(h))
+    w_low = int(np.floor(w))
+    h_high = h_low + 1
+    w_high = w_low + 1
+
+    lh = h - h_low
+    lw = w - w_low
+    hh = 1 - lh
+    hw = 1 - lw
+
+    v1 = 0
+    if h_low >= 0 and w_low >= 0:
+        v1 = data_im[h_low, w_low]
+    v2 = 0
+    if h_low >= 0 and w_high <= width - 1:
+        v2 = data_im[h_low, w_high]
+    v3 = 0
+    if h_high <= height - 1 and w_low >= 0:
+        v3 = data_im[h_high, w_low]
+    v4 = 0
+    if h_high <= height - 1 and w_high <= width - 1:
+        v4 = data_im[h_high, w_high]
+
+    w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw
+    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
+
+    return val
+
+
+def dconv_im2col_gemm(input, offset, filter, group, conv_param):
+    in_n, in_c, in_h, in_w = input.shape
+    out_c, f_c, f_h, f_w = filter.shape
+
+    assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w)
+    assert f_c * group == in_c
+    assert np.mod(out_c, group) == 0
+
+    stride, pad, dilation = conv_param['stride'], conv_param['pad'],\
+                            conv_param['dilation']
+    out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
+    out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
+    assert out_h == in_h
+    assert out_w == in_w
+
+    col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w))
+    for n in range(in_n):
+        for c in range(in_c):
+            for h in range(out_h):
+                for w in range(out_w):
+                    for kh in range(f_h):
+                        for kw in range(f_w):
+                            offset_h_table = \
+                                offset[n, ::2, h, w].reshape(f_h, f_w)
+                            offset_w_table = \
+                                offset[n, 1::2, h, w].reshape(f_h, f_w)
+                            offset_h = offset_h_table[kh, kw]
+                            offset_w = offset_w_table[kh, kw]
+                            val = 0
+                            im_h = h * stride[0] + kh * dilation[0] \
+                                + offset_h - pad[0]
+                            im_w = w * stride[1] + kw * dilation[1] \
+                                + offset_w - pad[1]
+                            if im_h > -1 and im_w > -1 and \
+                                    im_h < in_h and im_w < in_w:
+                                val = dmc_bilinear(input[n, c], in_h, in_w,
+                                                   im_h, im_w)
+
+                            col_buffer[n, c * f_h * f_w + kh * f_w + kw, h *
+                                       in_w + w] = val
+
+    out = np.zeros((in_n, group, int(out_c // group), out_h * out_w))
+    weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w)
+    col_buffer = col_buffer.reshape(
+        (in_n, group, int(in_c // group * f_h * f_w), in_h * in_w))
+    for n in range(in_n):
+        for g in range(group):
+            out[n, g] = np.matmul(weight[g], col_buffer[n, g])
+    out = out.reshape(in_n, out_c, out_h, out_w)
+    return out
+
+
+class TestModulatedDeformableConvOp(OpTest):
+    def setUp(self):
+        self.op_type = "deformable_conv_v1"
+        self.dtype = np.float32
+        self.init_group()
+        self.init_dilation()
+        self.init_test_case()
+
+        conv_param = {
+            'stride': self.stride,
+            'pad': self.pad,
+            'dilation': self.dilations
+        }
+
+        input = np.random.random(self.input_size).astype(self.dtype)
+        offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
+        filter = np.random.random(self.filter_size).astype(self.dtype)
+
+        output = dconv_im2col_gemm(input, offset, filter, self.groups,
+                                   conv_param)
+        output = output.astype(self.dtype)
+        self.inputs = {
+            'Input': OpTest.np_dtype_to_fluid_dtype(input),
+            'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
+            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
+        }
+        self.attrs = {
+            'strides': self.stride,
+            'paddings': self.pad,
+            'groups': self.groups,
+            'deformable_groups': self.deformable_groups,
+            'im2col_step': self.im2col_step,
+            'dilations': self.dilations,
+        }
+        self.outputs =
{'Output': output} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['Input', 'Offset', 'Filter'], 'Output', max_relative_error=0.05) + + def test_check_grad_no_filter(self): + self.check_grad( + ['Input', 'Offset'], + 'Output', + max_relative_error=0.1, + no_grad_set=set(['Filter'])) + + def init_test_case(self): + self.pad = [1, 1] + self.stride = [1, 1] + self.dilations = [1, 1] + self.input_size = [2, 4, 4, 4] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [4, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + + def init_dilation(self): + self.dilations = [1, 1] + + def init_group(self): + self.groups = 1 + + +class TestWithStride(TestModulatedDeformableConvOp): + def init_test_case(self): + self.pad = [3, 3] + self.stride = [2, 2] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + + +class TestWithDilation(TestModulatedDeformableConvOp): + def init_test_case(self): + self.pad = [2, 2] + self.stride = [1, 1] + self.input_size = [2, 3, 4, 4] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 3, 3] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + + def init_dilation(self): + self.dilations = [2, 2] + + +class TestWith1x1(TestModulatedDeformableConvOp): + def init_test_case(self): + self.pad = [0, 0] + self.stride = [1, 1] + self.input_size = [2, 3, 5, 5] # NCHW + assert np.mod(self.input_size[1], self.groups) == 0 + f_c = self.input_size[1] // self.groups + self.filter_size = [6, f_c, 1, 1] + self.im2col_step = 1 + self.deformable_groups = 1 + offset_c = 2 * self.deformable_groups * self.filter_size[ + 2] * self.filter_size[3] + self.offset_size = [ + self.input_size[0], offset_c, self.input_size[2], self.input_size[3] + ] + + +class TestWithGroup(TestModulatedDeformableConvOp): + def init_group(self): + self.groups = 2 + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 47779859a0..ad8a42700e 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -2276,32 +2276,31 @@ class TestBook(LayerTest): print(str(program)) def test_deformable_conv(self): - if core.is_compiled_with_cuda(): - with program_guard(fluid.default_main_program(), - fluid.default_startup_program()): - input = layers.data( - name='input', - append_batch_size=False, - shape=[2, 3, 32, 32], - dtype="float32") - offset = layers.data( - name='offset', - append_batch_size=False, - shape=[2, 18, 32, 32], - dtype="float32") - mask = layers.data( - name='mask', - append_batch_size=False, - 
shape=[2, 9, 32, 32], - dtype="float32") - out = layers.deformable_conv( - input=input, - offset=offset, - mask=mask, - num_filters=2, - filter_size=3, - padding=1) - return (out) + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = layers.data( + name='input', + append_batch_size=False, + shape=[2, 3, 32, 32], + dtype="float32") + offset = layers.data( + name='offset', + append_batch_size=False, + shape=[2, 18, 32, 32], + dtype="float32") + mask = layers.data( + name='mask', + append_batch_size=False, + shape=[2, 9, 32, 32], + dtype="float32") + out = layers.deformable_conv( + input=input, + offset=offset, + mask=mask, + num_filters=2, + filter_size=3, + padding=1) + return (out) def test_unfold(self): with self.static_graph(): @@ -2338,6 +2337,29 @@ class TestBook(LayerTest): trans_std=0.1) return (out) + def test_deformable_conv_v1(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + input = layers.data( + name='input', + append_batch_size=False, + shape=[2, 3, 32, 32], + dtype="float32") + offset = layers.data( + name='offset', + append_batch_size=False, + shape=[2, 18, 32, 32], + dtype="float32") + out = layers.deformable_conv( + input=input, + offset=offset, + mask=None, + num_filters=2, + filter_size=3, + padding=1, + modulated=False) + return (out) + def test_retinanet_target_assign(self): with program_guard(fluid.default_main_program(), fluid.default_startup_program()):