Add deformable conv v2 op,test=develop (#17145)

* unit commits, test=develop

* update API.spec, test=develop
dependabot/pip/python/requests-2.20.0
Bai Yifan 6 years ago committed by GitHub
parent aca53535e0
commit bba57cdd82
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -110,7 +110,7 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "deformable_conv_op" "dgc_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()

@ -237,6 +237,7 @@ paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=
paddle.fluid.layers.continuous_value_model (ArgSpec(args=['input', 'cvm', 'use_cvm'], varargs=None, keywords=None, defaults=(True,)), ('document', '94e2819b7c9715ea71b62e9c78f36b29'))
paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=None, defaults=None), ('document', '3126e3039e752ce26077f1efaca355c6'))
paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'ccf6bb7912afd2818d24bc45461e807a'))
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', 'c896b66265a60bd3c5510f66e6e02919'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', 'adf285346e23316097f7789b572491e9'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'dce69a78638da8f7ad80b1fc00ed2029'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6'))

@ -47,7 +47,8 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op sync_batch_norm_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
sync_batch_norm_op deformable_conv_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above
@ -65,6 +66,8 @@ if (WITH_GPU)
op_library(sync_batch_norm_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(sync_batch_norm);\n")
endif()
op_library(deformable_conv_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(deformable_conv);\n")
else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -202,6 +202,7 @@ __all__ = [
'continuous_value_model',
'where',
'sign',
'deformable_conv',
]
kIgnoreIndex = -100
@ -11745,3 +11746,175 @@ def sign(x):
helper.append_op(type='sign', inputs={'X': [x]}, outputs={'Out': [out]})
return out
def deformable_conv(input,
offset,
mask,
num_filters,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=None,
deformable_groups=None,
im2col_step=None,
param_attr=None,
bias_attr=None,
name=None):
"""
**Deformable Convolution Layer**
Compute 2-D deformable convolution on 4-D input.
Given input image x, output feature map y, the deformable convolution operation can be expressed as follow:
.. math::
y(p) = \sum_{k=1}^{K}{w_k * x(p + p_k + \Delta p_k) * \Delta m_k}
Where :math:`\Delta p_k` and :math:`\Delta m_k` are the learnable offset and modulation scalar for the k-th location, respectively.
Refer to `Deformable ConvNets v2: More Deformable, Better Results
<https://arxiv.org/abs/1811.11168v2>`_ .
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
Offset shape: :math:`(N, 2 * deformable\_groups * H_f * H_w, H_{in}, W_{in})`
Mask shape: :math:`(N, deformable\_groups * H_f * H_w, H_{in}, W_{in})`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
input (Variable): The input image with [N, C, H, W] format.
offset (Variable): The input coord offset of deformable convolution layer.
Mask (Variable): The input mask of deformable covolution layer.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups (int): The groups number of the deformable conv layer. According to
grouped convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1.
deformable_groups (int): The number of deformable group partitions.
Default: deformable_groups = 1.
im2col_step (int): Maximum number of images per im2col computation;
The total batch size should be divisable by this value or smaller
than this value; if you face out of memory problem, you can try
to use a smaller value here.
Default: im2col_step = 64.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of deformable conv. If it is set to None or one attribute of ParamAttr,
deformable conv will create ParamAttr as param_attr.
If the Initializer of the param_attr is not set, the parameter is
initialized with :math:`Normal(0.0, std)`, and the
:math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of
deformable conv layer. If it is set to False, no bias will be added
to the output units. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: The tensor variable storing the deformable convolution \
result.
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
offset = fluid.layers.data(name='offset', shape=[18, 32, 32], dtype='float32')
mask = fluid.layers.data(name='mask', shape=[9, 32, 32], dtype='float32')
out = fluid.layers.deformable_conv(input=data, offset=offset, mask=mask,
num_filters=2, filter_size=3, padding=1)
"""
num_channels = input.shape[1]
assert param_attr is not False, "param_attr should not be False here."
helper = LayerHelper('deformable_conv', **locals())
dtype = helper.input_dtype()
if not isinstance(input, Variable):
raise TypeError("Input of deformable_conv must be Variable")
if not isinstance(offset, Variable):
raise TypeError("Input Offset of deformable_conv must be Variable")
if not isinstance(mask, Variable):
raise TypeError("Input Mask of deformable_conv must be Variable")
if groups is None:
num_filter_channels = num_channels
else:
if num_channels % groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels // groups
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
stride = utils.convert_to_list(stride, 2, 'stride')
padding = utils.convert_to_list(padding, 2, 'padding')
dilation = utils.convert_to_list(dilation, 2, 'dilation')
input_shape = input.shape
filter_shape = [num_filters, int(num_filter_channels)] + filter_size
def _get_default_param_initializer():
filter_elem_num = filter_size[0] * filter_size[1] * num_channels
std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0)
filter_param = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype,
default_initializer=_get_default_param_initializer())
pre_bias = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='deformable_conv',
inputs={
'Input': input,
'Filter': filter_param,
'Offset': offset,
'Mask': mask,
},
outputs={"Output": pre_bias},
attrs={
'strides': stride,
'paddings': padding,
'dilations': dilation,
'groups': groups,
'deformable_groups': deformable_groups,
'im2col_step': im2col_step,
})
output = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
return output

@ -1957,6 +1957,34 @@ class TestBook(LayerTest):
self.assertIsNotNone(out)
print(str(program))
def test_deformable_conv(self):
if core.is_compiled_with_cuda():
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='input',
append_batch_size=False,
shape=[2, 3, 32, 32],
dtype="float32")
offset = layers.data(
name='offset',
append_batch_size=False,
shape=[2, 18, 32, 32],
dtype="float32")
mask = layers.data(
name='mask',
append_batch_size=False,
shape=[2, 9, 32, 32],
dtype="float32")
out = layers.deformable_conv(
input=input,
offset=offset,
mask=mask,
num_filters=2,
filter_size=3,
padding=1)
return (out)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save