Add var_conv_2d op (#18518)

* fix overflow by int32 mul test=develop

* fix reference nullptr

* fix codestyle test=develop

* modify to point in ContextProjectFunctor test=develop

* modify to point in ContextProjectFunctor test=develop

* modify . to -> test=develop

* add var_conv_2d op test=develop

* edit api.spec test=develop

* ignore unittest if with_mkl=off test=develop

* fix python3 division test=develop

* fix ignore unittest bug test=develop

* remove useless code test=develop

* modify api.spec test=develop

* modify default_grad.spec test=develop
padding_in_crf
Kevin 6 years ago committed by Tao Luo
parent 81fe02c3fe
commit e681d65515

@ -267,6 +267,7 @@ paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defau
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', '4d83ba6b971cfd590493b0925b3e081e'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '99c03e3f249e36854f87dedaa17c8f35'))
paddle.fluid.layers.var_conv_2d (ArgSpec(args=['input', 'row', 'col', 'input_channel', 'output_channel', 'filter_size', 'stride', 'param_attr', 'act', 'dtype', 'name'], varargs=None, keywords=None, defaults=(1, None, None, 'float32', None)), ('document', '7a8b8ade5512c95f9ea30261d33ded6c'))
paddle.fluid.layers.shard_index (ArgSpec(args=['input', 'index_num', 'nshards', 'shard_id', 'ignore_value'], varargs=None, keywords=None, defaults=(-1,)), ('document', '5786fdbba6753ecd6cbce5e6b0889924'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '9d7806e31bdf727c1a23b8782a09b545'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'cccb6eb5410c822e5307c947aca2c899'))

@ -41,3 +41,4 @@ tensor_array_to_tensor
transpose
unpool
unsqueeze
var_conv_2d

@ -48,8 +48,13 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
# Operators listed in OP_ONLY_MKL are added to the EXCLUDES list of
# register_operators() below. The list stays empty when WITH_MKL is on.
set(OP_ONLY_MKL "")
if(NOT WITH_MKL)
  list(APPEND OP_ONLY_MKL var_conv_2d_op)
endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op conv_fusion_op
sync_batch_norm_op deformable_conv_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
sync_batch_norm_op deformable_conv_op ${OP_ONLY_MKL} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
if (WITH_GPU)
# warpctc_op needs cudnn 7 above

File diff suppressed because it is too large Load Diff

@ -0,0 +1,45 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;
// Operator class for var_conv_2d, derived from framework::OperatorWithKernel.
// Only the declaration lives in this header; InferShape is defined elsewhere.
// NOTE(review): the "OP" suffix is upper-cased here, while the sibling classes
// in this header use "Op" (VarConv2dOpGrad, VarConv2dOpMaker). Renaming would
// touch every reference to this class, so it is only flagged here.
class VarConv2dOP : public framework::OperatorWithKernel {
public:
// Inherit all constructors of the base class unchanged.
using framework::OperatorWithKernel::OperatorWithKernel;
// Shape-inference hook overriding the base-class virtual.
void InferShape(framework::InferShapeContext* ctx) const override;
};
// Operator class derived from framework::OperatorWithKernel; by its name this
// is presumably the gradient counterpart of VarConv2dOP (confirm against the
// operator registration). Only the declaration lives in this header.
class VarConv2dOpGrad : public framework::OperatorWithKernel {
public:
// Inherit all constructors of the base class unchanged.
using framework::OperatorWithKernel::OperatorWithKernel;
// Shape-inference hook overriding the base-class virtual.
void InferShape(framework::InferShapeContext* ctx) const override;
};
// Proto/checker maker derived from framework::OpProtoAndCheckerMaker.
// Make() is only declared here; its definition (presumably where the op's
// inputs, outputs and attributes are declared -- confirm in the .cc file)
// lives outside this header.
class VarConv2dOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle

@ -210,6 +210,7 @@ __all__ = [
'deformable_conv',
'unfold',
'deformable_roi_pooling',
'var_conv_2d',
'shard_index',
]
@ -12729,6 +12730,121 @@ def deformable_roi_pooling(input,
return output
def var_conv_2d(input,
                row,
                col,
                input_channel,
                output_channel,
                filter_size,
                stride=1,
                param_attr=None,
                act=None,
                dtype='float32',
                name=None):
    """
    The var_conv_2d layer calculates the output based on the :attr:`input` with
    variable length, row, col, input channel, filter size and strides.
    :attr:`input`, :attr:`row` and :attr:`col` are all 1-level LoDTensors. The
    convolution operation is the same as the conv2d layer with padding.
    Besides, input.dims[1] should be 1.

    .. code-block:: text

        If input_channel is 2 and given row lodTensor and col lodTensor as follows:
            row.lod = [[5, 4]]
            col.lod = [[6, 7]]
        input is a lodTensor:
            input.lod = [[60, 56]]  # where 60 = input_channel * 5 * 6
            input.dims = [116, 1]   # where 116 = 60 + 56

        If set output_channel is 3, filter_size is [3, 3], stride is [1, 1]:
            output.lod = [[90, 84]] # where 90 = output_channel * [(5-1)/stride + 1] * [(6-1)/stride + 1]
            output.dims = [174, 1]  # where 174 = 90 + 84

    Args:
        input (Variable): The input should be a 1-level LoDTensor with dims[1] equal to 1.
        row (Variable): The row should be a 1-level LoDTensor that provides height information.
        col (Variable): The col should be a 1-level LoDTensor that provides width information.
        input_channel (int): The number of input channels.
        output_channel (int): The number of output channels.
        filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
            it must contain two integers, (filter_size_H, filter_size_W).
            Otherwise, the filter will be a square.
        stride (int|tuple): The stride size. If stride is a tuple, it must
            contain two integers, (stride_H, stride_W). Otherwise, the
            stride_H = stride_W = stride. Default: stride = 1.
        param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
            of var_conv2d. If it is set to None or one attribute of ParamAttr, var_conv2d
            will create ParamAttr as param_attr. If the Initializer of the param_attr
            is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
            and the :math:`std` is :math:`(\\frac{2.0 }{filter\\_elem\\_num})^{0.5}`. Default: None.
        act (str): Activation type, if it is set to None, activation is not appended.
            Default: None
        dtype ('float32'): The data type of parameter and output.
        name (str|None): A name for this layer(optional). If set None, the layer
            will be named automatically. Default: None

    Returns:
        Variable: Output variable with LoD specified by this layer.

    Examples:
        .. code-block:: python

            import numpy as np
            from paddle.fluid import layers

            x_lod_tensor = layers.data(name='x', shape=[1], lod_level=1)
            row_lod_tensor = layers.data(name='row', shape=[6], lod_level=1)
            col_lod_tensor = layers.data(name='col', shape=[6], lod_level=1)
            out = layers.var_conv_2d(input=x_lod_tensor,
                                     row=row_lod_tensor,
                                     col=col_lod_tensor,
                                     input_channel=3,
                                     output_channel=5,
                                     filter_size=[3, 3],
                                     stride=1)
    """
    # Must remain the first statement: locals() is forwarded to LayerHelper,
    # so it has to contain exactly the call arguments and nothing else.
    helper = LayerHelper('var_conv_2d', **locals())

    x_shape = list(input.shape)
    assert len(x_shape) == 2, (
        "var_conv_2d expects a 2-D input, but got shape %s" % str(x_shape))

    # Normalize scalar arguments into two-element [H, W] lists.
    filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
    stride = utils.convert_to_list(stride, 2, 'stride')

    # The filter is stored as a 2-D matrix:
    # [output_channel, input_channel * kernel_h * kernel_w].
    filter_shape = [
        int(output_channel),
        int(input_channel) * filter_size[0] * filter_size[1]
    ]
    filter_param = helper.create_parameter(
        attr=helper.param_attr, shape=filter_shape, dtype=dtype)

    conv_res = helper.create_variable_for_type_inference(dtype)
    # Second operator output ("Col"); created with stop_gradient=True and not
    # returned to the caller.
    tmp_res = helper.create_variable_for_type_inference(
        dtype, stop_gradient=True)

    helper.append_op(
        type='var_conv_2d',
        inputs={
            'X': input,
            'ROW': row,
            'COLUMN': col,
            'W': filter_param,
        },
        outputs={"Out": conv_res,
                 "Col": tmp_res},
        attrs={
            'InputChannel': input_channel,
            'OutputChannel': output_channel,
            'StrideH': stride[0],
            'StrideW': stride[1],
            'KernelH': filter_size[0],
            'KernelW': filter_size[1],
        })

    return helper.append_activation(conv_res)


def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
"""
This layer creates the sharded index for input. This layers is used in

@ -74,6 +74,11 @@ if(NOT WITH_MKLML)
list(REMOVE_ITEM TEST_OPS test_fusion_seqexpand_concat_fc_op)
endif()
# Drop the var_conv_2d unit test from TEST_OPS when building without MKL.
if(NOT WITH_MKL)
  list(REMOVE_ITEM TEST_OPS test_var_conv_2d)
endif()
if(WITH_GPU OR NOT WITH_MKLML)
# matmul with multiple heads need MKL support
LIST(REMOVE_ITEM TEST_OPS test_matmul_op_with_head)

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save