add deformable conv v1 op and cpu version of deformable conv v2 (#18500)

* add deformable conv v1 op, test=develop
expand_as_op_1
chengjuntao 5 years ago committed by GitHub
parent 40c66f8df9
commit 00efd1d8a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here. # Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op" foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op") "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}") if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1) set(pybind_flag 1)
endif() endif()

@ -285,7 +285,7 @@ paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=
paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'c03490ffaa1b78258747157c313db4cd')) paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', 'c03490ffaa1b78258747157c313db4cd'))
paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', 'b1e1487760295e1ff55307b880a99e18')) paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', 'b1e1487760295e1ff55307b880a99e18'))
paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'fa2f457a81714430c5677c2d68744728')) paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'fa2f457a81714430c5677c2d68744728'))
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e')) paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'modulated', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, True, None)), ('document', '335193ac57d41d7199f8d26d30c069b1'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6')) paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35')) paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', 'b6ea7d4ddeacae85e37d1e47d5262948')) paddle.fluid.layers.match_matrix_tensor (ArgSpec(args=['x', 'y', 'channel_num', 'act', 'param_attr', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, 'float32', None)), ('document', 'b6ea7d4ddeacae85e37d1e47d5262948'))

@ -55,7 +55,7 @@ if (NOT WITH_MKL)
endif() endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
sync_batch_norm_op deformable_conv_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) sync_batch_norm_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU) if (WITH_GPU)
# warpctc_op needs cudnn 7 above # warpctc_op needs cudnn 7 above
@ -73,8 +73,6 @@ if (WITH_GPU)
op_library(sync_batch_norm_op) op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n") file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
endif() endif()
op_library(deformable_conv_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(deformable_conv);\n")
else() else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif() endif()

@ -0,0 +1,37 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refs to
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
//
// Copyright (c) 2017 Microsoft
// Licensed under The Apache-2.0 License [see LICENSE for details]
// \file deformable_psroi_pooling.cu
// \brief
// \author Yi Li, Guodong Zhang, Jifeng Dai
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
// Element-wise accumulation kernel: adds dweight_3d[i] into filter_grad[i] for
// every i in [0, nthreads).
//
// Each thread starts at its global linear index and advances by the total
// number of launched threads (grid-stride loop), so any 1-D launch
// configuration covers all nthreads elements.
//
// The parameters n, height and width are accepted but not read by this kernel.
template <typename T>
__global__ void FilterGradAddupCUDAKernel(const int nthreads, const int n,
                                          const int height, const int width,
                                          const T* dweight_3d, T* filter_grad) {
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;
  size_t pos = first;
  while (pos < nthreads) {
    filter_grad[pos] += dweight_3d[pos];
    pos += stride;
  }
}

@ -0,0 +1,149 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refs to
// https://github.com/msracver/Deformable-ConvNets/blob/master/faster_rcnn/operator_cxx/deformable_convolution.cu
//
// Copyright (c) 2017 Microsoft
// Licensed under The Apache-2.0 License [see LICENSE for details]
// \file deformable_psroi_pooling.cu
// \brief
// \author Yi Li, Guodong Zhang, Jifeng Dai
#pragma once
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/hostdevice.h"
// Returns the bilinear-interpolation weight that the integer pixel (h, w)
// carries for the fractional sampling location (argmax_h, argmax_w).
//
// The result is 0 when the sampling location is at or beyond the -1 / height /
// width limits, or when (h, w) is not one of the four integer neighbours
// (floor and floor + 1 in each axis) of the sampling location.
template <typename T>
HOSTDEVICE T DmcnGetGradientWeight(T argmax_h, T argmax_w, const int h,
                                   const int w, const int height,
                                   const int width) {
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
      argmax_w >= width) {
    return 0;
  }

  const int row_lo = floor(argmax_h);
  const int col_lo = floor(argmax_w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // The four neighbour tests are mutually exclusive (lo != hi on each axis),
  // so at most one of the branches below can fire.
  if (h == row_lo && w == col_lo) {
    return (h + 1 - argmax_h) * (w + 1 - argmax_w);
  }
  if (h == row_lo && w == col_hi) {
    return (h + 1 - argmax_h) * (argmax_w + 1 - w);
  }
  if (h == row_hi && w == col_lo) {
    return (argmax_h + 1 - h) * (w + 1 - argmax_w);
  }
  if (h == row_hi && w == col_hi) {
    return (argmax_h + 1 - h) * (argmax_w + 1 - w);
  }
  return 0;
}
// Returns the partial derivative of the bilinear sample of im_data at
// (argmax_h, argmax_w) with respect to one sampling coordinate:
//   bp_dir == 0 : derivative with respect to argmax_h (row coordinate)
//   bp_dir == 1 : derivative with respect to argmax_w (column coordinate)
// Any other bp_dir yields 0.
//
// im_data is read as a row-major plane with row pitch data_width. Neighbours
// that fall outside [0, height - 1] x [0, width - 1] are not read and
// contribute nothing. The result is 0 when the sampling location is at or
// beyond the -1 / height / width limits.
template <typename T>
HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h, T argmax_w, const int height,
                                     const int width, const T* im_data,
                                     const int data_width, const int bp_dir) {
  if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 ||
      argmax_w >= width) {
    return 0;
  }

  const int row_lo = floor(argmax_h);
  const int col_lo = floor(argmax_w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Validity of each neighbouring row / column inside the image plane.
  const bool top_ok = row_lo >= 0;
  const bool bottom_ok = row_hi <= height - 1;
  const bool left_ok = col_lo >= 0;
  const bool right_ok = col_hi <= width - 1;

  T weight = 0;
  if (bp_dir == 0) {
    // Top neighbours enter negatively, bottom neighbours positively.
    if (top_ok && left_ok) {
      weight += -1 * (col_lo + 1 - argmax_w) *
                im_data[row_lo * data_width + col_lo];
    }
    if (top_ok && right_ok) {
      weight +=
          -1 * (argmax_w - col_lo) * im_data[row_lo * data_width + col_hi];
    }
    if (bottom_ok && left_ok) {
      weight +=
          (col_lo + 1 - argmax_w) * im_data[row_hi * data_width + col_lo];
    }
    if (bottom_ok && right_ok) {
      weight += (argmax_w - col_lo) * im_data[row_hi * data_width + col_hi];
    }
  } else if (bp_dir == 1) {
    // Left neighbours enter negatively, right neighbours positively.
    if (top_ok && left_ok) {
      weight += -1 * (row_lo + 1 - argmax_h) *
                im_data[row_lo * data_width + col_lo];
    }
    if (top_ok && right_ok) {
      weight +=
          (row_lo + 1 - argmax_h) * im_data[row_lo * data_width + col_hi];
    }
    if (bottom_ok && left_ok) {
      weight +=
          -1 * (argmax_h - row_lo) * im_data[row_hi * data_width + col_lo];
    }
    if (bottom_ok && right_ok) {
      weight += (argmax_h - row_lo) * im_data[row_hi * data_width + col_hi];
    }
  }
  return weight;
}
// Bilinearly samples the row-major plane bottom_data (row pitch data_width) at
// the fractional location (h, w).
//
// The four integer neighbours (floor and floor + 1 on each axis) are blended
// with the usual bilinear weights. A neighbour outside
// [0, height - 1] x [0, width - 1] is not read and is treated as 0.
template <typename T>
HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data, const int data_width,
                                const int height, const int width, T h, T w) {
  const int row_lo = floor(h);
  const int col_lo = floor(w);
  const int row_hi = row_lo + 1;
  const int col_hi = col_lo + 1;

  // Fractional distances to the low neighbour and their complements.
  T lh = h - row_lo;
  T lw = w - col_lo;
  T hh = 1 - lh;
  T hw = 1 - lw;

  // Corner values; out-of-range corners stay at zero.
  T v1 = 0;
  if (row_lo >= 0 && col_lo >= 0) {
    v1 = bottom_data[row_lo * data_width + col_lo];
  }
  T v2 = 0;
  if (row_lo >= 0 && col_hi <= width - 1) {
    v2 = bottom_data[row_lo * data_width + col_hi];
  }
  T v3 = 0;
  if (row_hi <= height - 1 && col_lo >= 0) {
    v3 = bottom_data[row_hi * data_width + col_lo];
  }
  T v4 = 0;
  if (row_hi <= height - 1 && col_hi <= width - 1) {
    v4 = bottom_data[row_hi * data_width + col_hi];
  }

  T w1 = hh * hw;
  T w2 = hh * lw;
  T w3 = lh * hw;
  T w4 = lh * lw;
  return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}

@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/operators/deformable_conv_op.h"
#include <memory>
#include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/conv_op.h"
namespace paddle { namespace paddle {
@ -197,7 +199,6 @@ class DeformableConvOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]), PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]),
deformable_groups, deformable_groups,
"mask filter must divide deformable group size."); "mask filter must divide deformable group size.");
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
} }
@ -274,5 +275,10 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp, REGISTER_OPERATOR(deformable_conv, ops::DeformableConvOp,
ops::DeformableConvOpMaker, ops::DeformableConvOpMaker,
ops::DeformableConvGradOpDescMaker); ops::DeformableConvGradOpDescMaker);
REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp); REGISTER_OPERATOR(deformable_conv_grad, ops::DeformableConvGradOp);
REGISTER_OP_CPU_KERNEL(deformable_conv, ops::DeformableConvCPUKernel<float>,
ops::DeformableConvCPUKernel<double>);
REGISTER_OP_CPU_KERNEL(deformable_conv_grad,
ops::DeformableConvGradCPUKernel<float>,
ops::DeformableConvGradCPUKernel<double>);

@ -24,6 +24,7 @@
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/deformable_conv_op.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h" #include "paddle/fluid/platform/cuda_primitives.h"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -13196,6 +13196,7 @@ def deformable_conv(input,
im2col_step=None, im2col_step=None,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
modulated=True,
name=None): name=None):
""" """
**Deformable Convolution Layer** **Deformable Convolution Layer**
@ -13203,13 +13204,22 @@ def deformable_conv(input,
Compute 2-D deformable convolution on 4-D input. Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow: Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
Deformable Convolution v2:
.. math:: .. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k} y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k}
Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, respectively. Deformable Convolution v1:
Refer to `Deformable ConvNets v2: More Deformable, Better Results
<https://arxiv.org/abs/1811.11168v2>`_ . .. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k)}
Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, respectively;
:math:`\Delta m_k` is fixed to one in deformable convolution v1. Please refer to `Deformable ConvNets v2: More Deformable, Better Results
<https://arxiv.org/abs/1811.11168v2>`_ and `Deformable Convolutional Networks <https://arxiv.org/abs/1703.06211>`_.
Example: Example:
- Input: - Input:
@ -13235,7 +13245,7 @@ def deformable_conv(input,
Args: Args:
input (Variable): The input image with [N, C, H, W] format. input (Variable): The input image with [N, C, H, W] format.
offset (Variable): The input coord offset of deformable convolution layer. offset (Variable): The input coordinate offset of deformable convolution layer.
Mask (Variable): The input mask of deformable covolution layer. Mask (Variable): The input mask of deformable covolution layer.
num_filters(int): The number of filter. It is as same as the output num_filters(int): The number of filter. It is as same as the output
image channel. image channel.
@ -13274,6 +13284,8 @@ def deformable_conv(input,
to the output units. If it is set to None or one attribute of ParamAttr, conv2d to the output units. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None. is not set, the bias is initialized zero. Default: None.
modulated (bool): Selects the deformable convolution version: v2 (modulated, uses ``mask``) \
when True, v1 (``mask`` is not used) when False. Default: True.
name (str|None): A name for this layer(optional). If set None, the layer name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None will be named automatically. Default: None
Returns: Returns:
@ -13285,12 +13297,22 @@ def deformable_conv(input,
Examples: Examples:
.. code-block:: python .. code-block:: python
#deformable conv v2:
import paddle.fluid as fluid import paddle.fluid as fluid
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32') data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32') offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32')
mask = fluid.layers.data(name='mask', shape=[9, 32, 32], dtype='float32') mask = fluid.layers.data(name='mask', shape=[9, 32, 32], dtype='float32')
out = fluid.layers.deformable_conv(input=data, offset=offset, mask=mask, out = fluid.layers.deformable_conv(input=data, offset=offset, mask=mask,
num_filters=2, filter_size=3, padding=1) num_filters=2, filter_size=3, padding=1, modulated=True)
#deformable conv v1:
import paddle.fluid as fluid
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32')
out = fluid.layers.deformable_conv(input=data, offset=offset, mask=None,
num_filters=2, filter_size=3, padding=1, modulated=False)
""" """
num_channels = input.shape[1] num_channels = input.shape[1]
@ -13303,8 +13325,6 @@ def deformable_conv(input,
raise TypeError("Input of deformable_conv must be Variable") raise TypeError("Input of deformable_conv must be Variable")
if not isinstance(offset, Variable): if not isinstance(offset, Variable):
raise TypeError("Input Offset of deformable_conv must be Variable") raise TypeError("Input Offset of deformable_conv must be Variable")
if not isinstance(mask, Variable):
raise TypeError("Input Mask of deformable_conv must be Variable")
if groups is None: if groups is None:
num_filter_channels = num_channels num_filter_channels = num_channels
@ -13334,23 +13354,42 @@ def deformable_conv(input,
pre_bias = helper.create_variable_for_type_inference(dtype) pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op( if modulated:
type='deformable_conv', helper.append_op(
inputs={ type='deformable_conv',
'Input': input, inputs={
'Filter': filter_param, 'Input': input,
'Offset': offset, 'Filter': filter_param,
'Mask': mask, 'Offset': offset,
}, 'Mask': mask,
outputs={"Output": pre_bias}, },
attrs={ outputs={"Output": pre_bias},
'strides': stride, attrs={
'paddings': padding, 'strides': stride,
'dilations': dilation, 'paddings': padding,
'groups': groups, 'dilations': dilation,
'deformable_groups': deformable_groups, 'groups': groups,
'im2col_step': im2col_step, 'deformable_groups': deformable_groups,
}) 'im2col_step': im2col_step,
})
else:
helper.append_op(
type='deformable_conv_v1',
inputs={
'Input': input,
'Filter': filter_param,
'Offset': offset,
},
outputs={"Output": pre_bias},
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'deformable_groups': deformable_groups,
'im2col_step': im2col_step,
})
output = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2) output = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
return output return output

@ -145,48 +145,35 @@ class TestModulatedDeformableConvOp(OpTest):
} }
self.outputs = {'Output': output} self.outputs = {'Output': output}
def has_cuda(self):
return core.is_compiled_with_cuda()
def test_check_output(self): def test_check_output(self):
if self.has_cuda(): self.check_output(atol=1e-5)
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=1e-5)
def test_check_grad(self): def test_check_grad(self):
if self.has_cuda(): self.check_grad(
place = core.CUDAPlace(0) {'Input', 'Offset', 'Mask', 'Filter'},
self.check_grad_with_place( 'Output',
place, {'Input', 'Offset', 'Mask', 'Filter'}, max_relative_error=0.05)
'Output',
max_relative_error=0.05)
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
if self.has_cuda(): self.check_grad(
place = core.CUDAPlace(0) ['Input', 'Offset', 'Mask'],
self.check_grad_with_place( 'Output',
place, ['Input', 'Offset', 'Mask'], max_relative_error=0.1,
'Output', no_grad_set=set(['Filter']))
max_relative_error=0.1,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
if self.has_cuda(): self.check_grad(
place = core.CUDAPlace(0) ['Filter', 'Offset', 'Mask'],
self.check_grad_with_place( 'Output',
place, ['Filter', 'Offset', 'Mask'], max_relative_error=0.1,
'Output', no_grad_set=set(['Input']))
max_relative_error=0.1,
no_grad_set=set(['Input']))
def test_check_grad_no_offset_no_mask(self): def test_check_grad_no_offset_no_mask(self):
if self.has_cuda(): self.check_grad(
place = core.CUDAPlace(0) ['Input', 'Filter'],
self.check_grad_with_place( 'Output',
place, ['Input', 'Filter'], max_relative_error=0.1,
'Output', no_grad_set=set(['Offset', 'Mask']))
max_relative_error=0.1,
no_grad_set=set(['Offset', 'Mask']))
def init_test_case(self): def init_test_case(self):
self.pad = [1, 1] self.pad = [1, 1]

@ -0,0 +1,240 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
def dmc_bilinear(data_im, height, width, h, w):
    """Bilinearly sample the 2-D array ``data_im`` at the location ``(h, w)``.

    Args:
        data_im: 2-D array indexed as ``data_im[row, col]``.
        height: Number of valid rows of ``data_im``.
        width: Number of valid columns of ``data_im``.
        h: Fractional row coordinate of the sample.
        w: Fractional column coordinate of the sample.

    Returns:
        The interpolated value. Any of the four integer neighbours that lies
        outside ``[0, height - 1] x [0, width - 1]`` contributes zero.
    """
    h_low = int(np.floor(h))
    w_low = int(np.floor(w))
    h_high = h_low + 1
    w_high = w_low + 1

    # Fractional distances to the low neighbour and their complements.
    lh = h - h_low
    lw = w - w_low
    hh = 1 - lh
    hw = 1 - lw

    v1 = 0
    if h_low >= 0 and w_low >= 0:
        v1 = data_im[h_low, w_low]
    v2 = 0
    if h_low >= 0 and w_high <= width - 1:
        v2 = data_im[h_low, w_high]
    v3 = 0
    if h_high <= height - 1 and w_low >= 0:
        v3 = data_im[h_high, w_low]
    v4 = 0
    if h_high <= height - 1 and w_high <= width - 1:
        v4 = data_im[h_high, w_high]

    w1, w2, w3, w4 = hh * hw, hh * lw, lh * hw, lh * lw
    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4

    return val


def dconv_im2col_gemm(input, offset, filter, group, conv_param):
    """NumPy reference for deformable convolution v1 (im2col followed by GEMM).

    Args:
        input: Array of shape ``(N, C_in, H, W)``.
        offset: Array of shape ``(N, 2 * f_h * f_w, H, W)``. Even channels hold
            the row offsets and odd channels the column offsets, one pair per
            kernel tap.
        filter: Array of shape ``(C_out, C_in // group, f_h, f_w)``.
        group: Number of convolution groups; must divide ``C_out`` and satisfy
            ``filter.shape[1] * group == C_in``.
        conv_param: Dict with the two-element sequences ``'stride'``, ``'pad'``
            and ``'dilation'``, each ordered ``(height, width)``.

    Returns:
        Array of shape ``(N, C_out, out_h, out_w)``.

    Note:
        This reference asserts that the output spatial size equals the input
        spatial size, because ``offset`` is indexed with output coordinates.
    """
    in_n, in_c, in_h, in_w = input.shape
    out_c, f_c, f_h, f_w = filter.shape

    assert offset.shape == (in_n, 2 * f_h * f_w, in_h, in_w)
    assert f_c * group == in_c
    assert np.mod(out_c, group) == 0

    stride, pad, dilation = conv_param['stride'], conv_param['pad'],\
        conv_param['dilation']
    out_h = 1 + (in_h + 2 * pad[0] - (dilation[0] * (f_h - 1) + 1)) // stride[0]
    out_w = 1 + (in_w + 2 * pad[1] - (dilation[1] * (f_w - 1) + 1)) // stride[1]
    assert out_h == in_h
    assert out_w == in_w

    col_buffer = np.zeros((in_n, in_c * f_h * f_w, in_h * in_w))
    for n in range(in_n):
        for h in range(out_h):
            for w in range(out_w):
                # The offset tables depend only on (n, h, w); build them once
                # per output pixel instead of once per channel and kernel tap.
                offset_h_table = offset[n, ::2, h, w].reshape(f_h, f_w)
                offset_w_table = offset[n, 1::2, h, w].reshape(f_h, f_w)
                for kh in range(f_h):
                    for kw in range(f_w):
                        im_h = h * stride[0] + kh * dilation[0] \
                            + offset_h_table[kh, kw] - pad[0]
                        # Width axis uses the width stride/dilation/pad
                        # (index 1), not the height ones.
                        im_w = w * stride[1] + kw * dilation[1] \
                            + offset_w_table[kh, kw] - pad[1]
                        # Sample only when the location is inside the image;
                        # the column bound is the input width.
                        if im_h > -1 and im_w > -1 and \
                                im_h < in_h and im_w < in_w:
                            for c in range(in_c):
                                col_buffer[n, c * f_h * f_w + kh * f_w + kw,
                                           h * in_w + w] = dmc_bilinear(
                                               input[n, c], in_h, in_w, im_h,
                                               im_w)

    out = np.zeros((in_n, group, int(out_c // group), out_h * out_w))
    weight = filter.reshape(group, int(out_c // group), f_c * f_h * f_w)
    col_buffer = col_buffer.reshape(
        (in_n, group, int(in_c // group * f_h * f_w), in_h * in_w))
    for n in range(in_n):
        for g in range(group):
            out[n, g] = np.matmul(weight[g], col_buffer[n, g])
    out = out.reshape(in_n, out_c, out_h, out_w)
    return out
class TestModulatedDeformableConvOp(OpTest):
    """Operator test for ``deformable_conv_v1``.

    ``setUp`` builds random inputs, computes the expected output with the
    NumPy reference ``dconv_im2col_gemm`` and registers inputs, attributes and
    outputs for the ``OpTest`` checks. Subclasses customise a case by
    overriding the ``init_group``, ``init_dilation`` and ``init_test_case``
    hooks, which ``setUp`` calls in that order.
    """

    def setUp(self):
        self.op_type = "deformable_conv_v1"
        self.dtype = np.float32
        self.init_group()
        self.init_dilation()
        self.init_test_case()

        conv_param = {
            'stride': self.stride,
            'pad': self.pad,
            'dilation': self.dilations
        }

        # Offsets are scaled by 10 so that sampling locations also land
        # outside the image.
        input = np.random.random(self.input_size).astype(self.dtype)
        offset = 10 * np.random.random(self.offset_size).astype(self.dtype)
        filter = np.random.random(self.filter_size).astype(self.dtype)

        output = dconv_im2col_gemm(input, offset, filter, self.groups,
                                   conv_param)
        output = output.astype(self.dtype)

        self.inputs = {
            'Input': OpTest.np_dtype_to_fluid_dtype(input),
            'Offset': OpTest.np_dtype_to_fluid_dtype(offset),
            'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
        }
        self.attrs = {
            'strides': self.stride,
            'paddings': self.pad,
            'groups': self.groups,
            'deformable_groups': self.deformable_groups,
            'im2col_step': self.im2col_step,
            'dilations': self.dilations,
        }
        self.outputs = {'Output': output}

    def test_check_output(self):
        """Compare the operator output with the NumPy reference."""
        self.check_output()

    def test_check_grad(self):
        """Check gradients with respect to Input, Offset and Filter."""
        self.check_grad(
            ['Input', 'Offset', 'Filter'], 'Output', max_relative_error=0.05)

    def test_check_grad_no_filter(self):
        """Check gradients of Input and Offset with Filter excluded."""
        self.check_grad(
            ['Input', 'Offset'],
            'Output',
            max_relative_error=0.1,
            no_grad_set=set(['Filter']))

    def init_test_case(self):
        """Default case: 3x3 kernel, stride 1, pad 1 on a 2x4x4x4 input.

        ``self.dilations`` is deliberately not assigned here: it is owned by
        ``init_dilation``, which ``setUp`` calls first. Assigning it here
        would silently override a subclass that customises only that hook.
        """
        self.pad = [1, 1]
        self.stride = [1, 1]
        self.input_size = [2, 4, 4, 4]  # NCHW
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [4, f_c, 3, 3]
        self.im2col_step = 1
        self.deformable_groups = 1
        # Two offset channels (row and column) per kernel tap and group.
        offset_c = 2 * self.deformable_groups * self.filter_size[
            2] * self.filter_size[3]
        self.offset_size = [
            self.input_size[0], offset_c, self.input_size[2], self.input_size[3]
        ]

    def init_dilation(self):
        """Hook providing the dilation; defaults to no dilation."""
        self.dilations = [1, 1]

    def init_group(self):
        """Hook providing the number of convolution groups; defaults to 1."""
        self.groups = 1
class TestWithStride(TestModulatedDeformableConvOp):
    """Variant with stride 2 and padding 3 on a 2x3x5x5 input (3x3 kernel)."""

    def init_test_case(self):
        self.pad = [3, 3]
        self.stride = [2, 2]
        self.input_size = [2, 3, 5, 5]  # NCHW
        batch, channels, in_h, in_w = self.input_size
        assert np.mod(channels, self.groups) == 0
        self.filter_size = [6, channels // self.groups, 3, 3]
        self.im2col_step = 1
        self.deformable_groups = 1
        # Two offset channels (row and column) per kernel tap and group.
        kernel_h, kernel_w = self.filter_size[2], self.filter_size[3]
        offset_c = 2 * self.deformable_groups * kernel_h * kernel_w
        self.offset_size = [batch, offset_c, in_h, in_w]
class TestWithDilation(TestModulatedDeformableConvOp):
    """Variant with dilation 2 and padding 2 on a 2x3x4x4 input (3x3 kernel)."""

    def init_test_case(self):
        self.pad = [2, 2]
        self.stride = [1, 1]
        self.input_size = [2, 3, 4, 4]  # NCHW
        batch, channels, in_h, in_w = self.input_size
        assert np.mod(channels, self.groups) == 0
        self.filter_size = [6, channels // self.groups, 3, 3]
        self.im2col_step = 1
        self.deformable_groups = 1
        # Two offset channels (row and column) per kernel tap and group.
        kernel_h, kernel_w = self.filter_size[2], self.filter_size[3]
        offset_c = 2 * self.deformable_groups * kernel_h * kernel_w
        self.offset_size = [batch, offset_c, in_h, in_w]

    def init_dilation(self):
        self.dilations = [2, 2]
class TestWith1x1(TestModulatedDeformableConvOp):
    """Variant with a 1x1 kernel, stride 1 and no padding on a 2x3x5x5 input."""

    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
        self.input_size = [2, 3, 5, 5]  # NCHW
        batch, channels, in_h, in_w = self.input_size
        assert np.mod(channels, self.groups) == 0
        self.filter_size = [6, channels // self.groups, 1, 1]
        self.im2col_step = 1
        self.deformable_groups = 1
        # Two offset channels (row and column) per kernel tap and group.
        kernel_h, kernel_w = self.filter_size[2], self.filter_size[3]
        offset_c = 2 * self.deformable_groups * kernel_h * kernel_w
        self.offset_size = [batch, offset_c, in_h, in_w]
class TestWithGroup(TestModulatedDeformableConvOp):
    """Variant that runs the default case with two convolution groups."""

    def init_group(self):
        group_count = 2
        self.groups = group_count
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
unittest.main()

@ -2276,32 +2276,31 @@ class TestBook(LayerTest):
print(str(program)) print(str(program))
def test_deformable_conv(self): def test_deformable_conv(self):
if core.is_compiled_with_cuda(): with program_guard(fluid.default_main_program(),
with program_guard(fluid.default_main_program(), fluid.default_startup_program()):
fluid.default_startup_program()): input = layers.data(
input = layers.data( name='input',
name='input', append_batch_size=False,
append_batch_size=False, shape=[2, 3, 32, 32],
shape=[2, 3, 32, 32], dtype="float32")
dtype="float32") offset = layers.data(
offset = layers.data( name='offset',
name='offset', append_batch_size=False,
append_batch_size=False, shape=[2, 18, 32, 32],
shape=[2, 18, 32, 32], dtype="float32")
dtype="float32") mask = layers.data(
mask = layers.data( name='mask',
name='mask', append_batch_size=False,
append_batch_size=False, shape=[2, 9, 32, 32],
shape=[2, 9, 32, 32], dtype="float32")
dtype="float32") out = layers.deformable_conv(
out = layers.deformable_conv( input=input,
input=input, offset=offset,
offset=offset, mask=mask,
mask=mask, num_filters=2,
num_filters=2, filter_size=3,
filter_size=3, padding=1)
padding=1) return (out)
return (out)
def test_unfold(self): def test_unfold(self):
with self.static_graph(): with self.static_graph():
@ -2338,6 +2337,29 @@ class TestBook(LayerTest):
trans_std=0.1) trans_std=0.1)
return (out) return (out)
def test_deformable_conv_v1(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='input',
append_batch_size=False,
shape=[2, 3, 32, 32],
dtype="float32")
offset = layers.data(
name='offset',
append_batch_size=False,
shape=[2, 18, 32, 32],
dtype="float32")
out = layers.deformable_conv(
input=input,
offset=offset,
mask=None,
num_filters=2,
filter_size=3,
padding=1,
modulated=False)
return (out)
def test_retinanet_target_assign(self): def test_retinanet_target_assign(self):
with program_guard(fluid.default_main_program(), with program_guard(fluid.default_main_program(),
fluid.default_startup_program()): fluid.default_startup_program()):

Loading…
Cancel
Save