Add comment for dygraph api (#17869)

* add api comment; test=develop

* fix fc dtype bug; test=develop

* remove float32 in default parameter; test=develop

* fix example bug; test=develop

* fix build once; test=develop

* fix num_channels bug; test=develop

* fix install check failed bug; test=develop
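Taken together, these changes let dygraph layers infer their input channels lazily and switch dtype arguments from core.VarDesc.VarType enums to plain strings. A minimal usage sketch of the post-change Conv2D, assembled from the new docstring example in the diff below (a sketch, not an authoritative reference):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
from paddle.fluid.dygraph.base import to_variable

data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
    # num_channels is no longer passed to the constructor; _build_once
    # infers it (here, 3) from the NCHW input on the first forward call.
    conv2d = Conv2D("conv2d", num_filters=2, filter_size=3)
    conv = conv2d(to_variable(data))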


@@ -84,7 +84,7 @@ class Conv2D(layers.Layer):
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
input (Variable): The input image with [N, C, H, W] format.
name_scope(str): The name of this class.
num_filters(int): The number of filters, which is the same as the
number of output image channels.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
@@ -118,12 +118,6 @@ class Conv2D(layers.Layer):
library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: The tensor variable storing the convolution and \
non-linearity activation result.
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
@@ -132,24 +126,36 @@ class Conv2D(layers.Layer):
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu")
with fluid.dygraph.guard():
conv2d = Conv2D( "conv2d", 2, 3)
data = to_variable( data )
conv = conv2d( data )
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
import numpy as np
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
conv2d = Conv2D("conv2d", 2, 3)
data = to_variable(data)
conv = conv2d(data)
"""
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
stride=1,
padding=0,
dilation=1,
groups=None,
use_cudnn=True,
act=None,
param_attr=None,
bias_attr=None,
dtype=core.VarDesc.VarType.FP32):
use_cudnn=True,
act=None,
dtype='float32'):
assert param_attr is not False, "param_attr should not be False here."
super(Conv2D, self).__init__(name_scope, dtype)
self._groups = groups
@@ -160,7 +166,11 @@ class Conv2D(layers.Layer):
if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False")
self._use_cudnn = use_cudnn
self._num_channels = num_channels
self._filter_size = filter_size
self._num_filters = num_filters
self._param_attr = param_attr
self._bias_attr = bias_attr
self._dtype = dtype
# if (self._num_channels == self._groups and
# num_filters % self._num_channels == 0 and not self._use_cudnn):
# self._l_type = 'depthwise_conv2d'
@@ -169,22 +179,26 @@ class Conv2D(layers.Layer):
# kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275
self._l_type = 'conv2d'
if groups is None:
num_filter_channels = num_channels
def _build_once(self, input):
self._num_channels = input.shape[1]
if self._groups is None:
num_filter_channels = self._num_channels
else:
if num_channels % groups != 0:
if self._num_channels % self._groups != 0:
raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels // groups
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size')
filter_shape = [num_filters, int(num_filter_channels)] + filter_size
num_filter_channels = self._num_channels // self._groups
filter_size = utils.convert_to_list(self._filter_size, 2, 'filter_size')
filter_shape = [self._num_filters, int(num_filter_channels)] + filter_size
def _get_default_param_initializer():
filter_elem_num = filter_size[0] * filter_size[1] * num_channels
filter_elem_num = filter_size[0] * filter_size[1] * self._num_channels
std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0)
self._filter_param = self.create_parameter(
attr=param_attr,
attr=self._param_attr,
shape=filter_shape,
dtype=self._dtype,
default_initializer=_get_default_param_initializer())
@@ -204,8 +218,8 @@ class Conv2D(layers.Layer):
type=core.VarDesc.VarType.RAW)
self._bias_param = self.create_parameter(
attr=bias_attr,
shape=[num_filters],
attr=self._bias_attr,
shape=[self._num_filters],
dtype=self._dtype,
is_bias=True)
@@ -653,14 +667,12 @@ class Conv3DTranspose(layers.Layer):
class Pool2D(layers.Layer):
# TODO, should delete this class
"""
${comment}
Args:
input (Variable): The input tensor of pooling operator. The format of
input tensor is NCHW, where N is batch size, C is
the number of channels, H is the height of the
feature, and W is the width of the feature.
name_scope(str): The name of this class.
pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two integers, (pool_size_Height, pool_size_Width).
Otherwise, the pool kernel size will be a square of an int.
@@ -814,8 +826,7 @@ class FC(layers.Layer):
out.shape = (1, 2)
Args:
input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of
the input tensor(s) is at least 2.
name(str): The name of this class.
size(int): The number of output units in this layer.
num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than
two dimensions. If this happens, the multidimensional tensor will first be flattened
@@ -833,10 +844,7 @@ class FC(layers.Layer):
If it is set to None, the bias is initialized zero. Default: None.
act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase.
name (str, default None): The name of this layer.
Returns:
Variable: The transformation result.
dtype(str): The dtype used for the weights.
Raises:
ValueError: If rank of the input tensor is less than 2.
@@ -844,26 +852,27 @@ class FC(layers.Layer):
Examples:
.. code-block:: python
# when input is single tensor
data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
fc = fluid.FC("fc", size=1000, act="tanh")
fc_res = fc(data)
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with fluid.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = to_variable(data)
fc_res = fc(data)
# when input are multiple tensors
data_1 = fluid.layers.data(name="data_1", shape=[32, 32], dtype="float32")
data_2 = fluid.layers.data(name="data_2", shape=[24, 36], dtype="float32")
fc = fluid.FC("fc", size=1000, act="tanh")
fc_res = fc([data_1, data_2])
"""
def __init__(self,
name_scope,
size,
num_flatten_dims=1,
param_attr=None,
bias_attr=None,
num_flatten_dims=1,
dtype=core.VarDesc.VarType.FP32,
act=None):
act=None,
is_test=False,
dtype="float32"):
super(FC, self).__init__(name_scope, dtype)
self._size = size
@@ -1048,7 +1057,7 @@ class BatchNorm(layers.Layer):
epsilon=1e-05,
param_attr=None,
bias_attr=None,
dtype=core.VarDesc.VarType.FP32,
dtype='float32',
data_layout='NCHW',
in_place=False,
moving_mean_name=None,
@@ -1064,8 +1073,8 @@ class BatchNorm(layers.Layer):
assert bias_attr is not False, "bias_attr should not be False in batch_norm."
if dtype == core.VarDesc.VarType.FP16:
self._dtype = core.VarDesc.VarType.FP32
if dtype == "float16":
self._dtype = "float32"
else:
self._dtype = dtype
@@ -1444,6 +1453,7 @@ class GRUUnit(layers.Layer):
Default: 'tanh'
gate_activation (string): The activation type for gates (actGate).
Default: 'sigmoid'
dtype(str): The dtype of the layer.
Returns:
tuple: The hidden value, reset-hidden value and gate values.
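Because dtype is now a plain string throughout, the float16 promotion above no longer needs core.VarDesc. A hedged sketch of the resulting behavior (BatchNorm's num_channels positional argument is assumed from the surrounding code):

import paddle.fluid as fluid
from paddle.fluid.dygraph import BatchNorm

with fluid.dygraph.guard():
    # Per the diff, a 'float16' request is promoted so the layer's
    # internal parameters are created in float32.
    bn = BatchNorm("bn", num_channels=10, dtype='float16')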

@@ -31,7 +31,7 @@ class SimpleLayer(Layer):
super(SimpleLayer, self).__init__(name_scope)
self._fc1 = nn.FC(self.full_name(),
3,
ParamAttr(initializer=Constant(value=0.1)))
param_attr=ParamAttr(initializer=Constant(value=0.1)))
def forward(self, inputs):
x = self._fc1(inputs)
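The param_attr= keyword above is required because the PR moves num_flatten_dims into FC's third positional slot; a ParamAttr passed positionally would now bind to num_flatten_dims. A self-contained sketch of the corrected call (input shape chosen for illustration):

import numpy as np
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.dygraph import FC
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.initializer import Constant

with fluid.dygraph.guard():
    # ParamAttr must go by keyword under the reordered signature.
    fc = FC("fc", 3, param_attr=ParamAttr(initializer=Constant(value=0.1)))
    out = fc(to_variable(np.ones([4, 8], dtype='float32')))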

@@ -55,7 +55,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,

@@ -47,7 +47,6 @@ class ConvBNLayer(fluid.dygraph.Layer):
self._conv = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,

@@ -51,7 +51,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,

@@ -25,7 +25,6 @@ from paddle.fluid.dygraph.base import to_variable
class SimpleImgConvPool(fluid.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
pool_size,
@@ -45,7 +44,6 @@ class SimpleImgConvPool(fluid.Layer):
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
@@ -76,10 +74,10 @@ class MNIST(fluid.Layer):
super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
self.full_name(), 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
self.full_name(), 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4
SIZE = 10

@@ -31,7 +31,6 @@ from test_imperative_base import new_program_scope
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
pool_size,
@@ -51,7 +50,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
@@ -82,10 +80,10 @@ class MNIST(fluid.dygraph.Layer):
super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
self.full_name(), 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
self.full_name(), 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4
SIZE = 10

@@ -80,7 +80,6 @@ class ConvBNPool(fluid.dygraph.Layer):
self.conv_0_layer = Conv2D(
self.full_name(),
channels[0],
out_ch[0],
3,
padding=1,
@@ -92,7 +91,6 @@
self.full_name(), out_ch[0], act=act, is_test=is_test)
self.conv_1_layer = Conv2D(
self.full_name(),
num_channels=channels[1],
num_filters=out_ch[1],
filter_size=3,
padding=1,

@@ -71,7 +71,6 @@ def optimizer_setting(params):
class ConvBNLayer(fluid.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
stride=1,
@@ -81,7 +80,6 @@ class ConvBNLayer(fluid.Layer):
self._conv = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
@@ -100,30 +98,22 @@ class ConvBNLayer(fluid.Layer):
class BottleneckBlock(fluid.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
stride,
shortcut=True):
def __init__(self, name_scope, num_filters, stride, shortcut=True):
super(BottleneckBlock, self).__init__(name_scope)
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
self.full_name(),
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
self.full_name(),
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
@@ -131,15 +121,12 @@ class BottleneckBlock(fluid.Layer):
if not shortcut:
self.short = ConvBNLayer(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
@@ -175,7 +162,6 @@ class ResNet(fluid.Layer):
self.conv = ConvBNLayer(
self.full_name(),
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
@@ -188,7 +174,6 @@
pool_type='max')
self.bottleneck_block_list = []
num_channels = 64
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
@@ -196,11 +181,9 @@
'bb_%d_%d' % (block, i),
BottleneckBlock(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut))
num_channels = bottleneck_block._num_channels_out
self.bottleneck_block_list.append(bottleneck_block)
shortcut = True

@@ -64,7 +64,6 @@ def optimizer_setting(params):
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
stride=1,
@@ -74,7 +73,6 @@ class ConvBNLayer(fluid.dygraph.Layer):
self._conv = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
@@ -131,20 +129,15 @@ class BottleneckBlock(fluid.dygraph.Layer):
super(BottleneckBlock, self).__init__(name_scope)
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=1)
self.full_name(), num_filters=num_filters, filter_size=1)
self.conv1 = ConvBNLayer(
self.full_name(),
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
groups=cardinality)
self.conv2 = ConvBNLayer(
self.full_name(),
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act='relu')
@@ -157,7 +150,6 @@ class BottleneckBlock(fluid.dygraph.Layer):
if not shortcut:
self.short = ConvBNLayer(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
@@ -200,7 +192,6 @@ class SeResNeXt(fluid.dygraph.Layer):
num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
@@ -218,7 +209,6 @@ class SeResNeXt(fluid.dygraph.Layer):
num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=3,
num_filters=3,
filter_size=7,
stride=2,
@@ -236,21 +226,18 @@ class SeResNeXt(fluid.dygraph.Layer):
num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=3,
num_filters=3,
filter_size=7,
stride=2,
act='relu')
self.conv1 = ConvBNLayer(
self.full_name(),
num_channels=64,
num_filters=3,
filter_size=7,
stride=2,
act='relu')
self.conv2 = ConvBNLayer(
self.full_name(),
num_channels=64,
num_filters=3,
filter_size=7,
stride=2,

@@ -20,3 +20,7 @@ import paddle.fluid as fluid
class TestInstallCheck(unittest.TestCase):
def test_install_check(self):
fluid.install_check.run_check()
if __name__ == '__main__':
unittest.main()
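The check this new test wraps can also be run directly; a minimal sketch:

import paddle.fluid as fluid

# Builds and runs a small network to verify the installation, raising
# (and failing the test above) if Paddle is not correctly installed.
fluid.install_check.run_check()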

@@ -190,8 +190,7 @@ class TestLayer(LayerTest):
with self.static_graph():
images = layers.data(name='pixel', shape=[3, 5, 5], dtype='float32')
conv2d = nn.Conv2D(
'conv2d', num_channels=3, num_filters=3, filter_size=[2, 2])
conv2d = nn.Conv2D('conv2d', num_filters=3, filter_size=[2, 2])
ret = conv2d(images)
static_ret2 = self.get_static_graph_result(
feed={'pixel': np.ones(
@@ -200,8 +199,7 @@ class TestLayer(LayerTest):
with self.dynamic_graph():
images = np.ones([2, 3, 5, 5], dtype='float32')
conv2d = nn.Conv2D(
'conv2d', num_channels=3, num_filters=3, filter_size=[2, 2])
conv2d = nn.Conv2D('conv2d', num_filters=3, filter_size=[2, 2])
dy_ret = conv2d(base.to_variable(images))
self.assertTrue(np.allclose(static_ret, dy_ret.numpy()))
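The assertion above verifies that the channel-inferring Conv2D matches the static-graph result; a hedged sketch of the dynamic half in isolation (output shape reasoned from 2x2 filters, stride 1, no padding):

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from paddle.fluid.dygraph.base import to_variable

images = np.ones([2, 3, 5, 5], dtype='float32')
with fluid.dygraph.guard():
    conv2d = nn.Conv2D('conv2d', num_filters=3, filter_size=[2, 2])
    dy_ret = conv2d(to_variable(images))
    print(dy_ret.numpy().shape)  # expected: (2, 3, 4, 4)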
