Add comment for dygraph api (#17869)

* add api commet; test=develop

* fix fc dtype bug; test=develop

* remove float32 in default parameter; test=develop

* fix exmpale bug; test=develop

* fix build once; test=develop

* fix num_chanels bug; test=develop

* fix install check failed bug; test=develop
dependabot/pip/python/requests-2.20.0
Hongyu Liu 6 years ago committed by GitHub
parent 209a3f4e09
commit 2a9d74f67c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

@ -84,7 +84,7 @@ class Conv2D(layers.Layer):
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1 W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args: Args:
input (Variable): The input image with [N, C, H, W] format. name_scope(str) : The name for this class.
num_filters(int): The number of filter. It is as same as the output num_filters(int): The number of filter. It is as same as the output
image channel. image channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple, filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
@ -118,12 +118,6 @@ class Conv2D(layers.Layer):
library is installed. Default: True library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended. act (str): Activation type, if it is set to None, activation is not appended.
Default: None Default: None
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically. Default: None
Returns:
Variable: The tensor variable storing the convolution and \
non-linearity activation result.
Raises: Raises:
ValueError: If the shapes of input, filter_size, stride, padding and ValueError: If the shapes of input, filter_size, stride, padding and
@ -131,25 +125,37 @@ class Conv2D(layers.Layer):
Examples: Examples:
.. code-block:: python .. code-block:: python
with fluid.dygraph.guard():
conv2d = Conv2D( "conv2d", 2, 3)
data = to_variable( data )
conv = conv2d( data )
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
import numpy as np
data = np.random.uniform( -1, 1, [10, 3, 32, 32] ).astype('float32')
with fluid.dygraph.guard():
conv2d = Conv2D( "conv2d", 2, 3)
data = to_variable( data )
conv = conv2d( data )
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
conv2d = fluid.layers.conv2d(input=data, num_filters=2, filter_size=3, act="relu")
""" """
def __init__(self, def __init__(self,
name_scope, name_scope,
num_channels,
num_filters, num_filters,
filter_size, filter_size,
stride=1, stride=1,
padding=0, padding=0,
dilation=1, dilation=1,
groups=None, groups=None,
use_cudnn=True,
act=None,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
dtype=core.VarDesc.VarType.FP32): use_cudnn=True,
act=None,
dtype='float32'):
assert param_attr is not False, "param_attr should not be False here." assert param_attr is not False, "param_attr should not be False here."
super(Conv2D, self).__init__(name_scope, dtype) super(Conv2D, self).__init__(name_scope, dtype)
self._groups = groups self._groups = groups
@ -160,7 +166,11 @@ class Conv2D(layers.Layer):
if not isinstance(use_cudnn, bool): if not isinstance(use_cudnn, bool):
raise ValueError("use_cudnn should be True or False") raise ValueError("use_cudnn should be True or False")
self._use_cudnn = use_cudnn self._use_cudnn = use_cudnn
self._num_channels = num_channels self._filter_size = filter_size
self._num_filters = num_filters
self._param_attr = param_attr
self._bias_attr = bias_attr
self._dtype = dtype
# if (self._num_channels == self._groups and # if (self._num_channels == self._groups and
# num_filters % self._num_channels == 0 and not self._use_cudnn): # num_filters % self._num_channels == 0 and not self._use_cudnn):
# self._l_type = 'depthwise_conv2d' # self._l_type = 'depthwise_conv2d'
@ -169,22 +179,26 @@ class Conv2D(layers.Layer):
# kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275 # kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275
self._l_type = 'conv2d' self._l_type = 'conv2d'
if groups is None: def _build_once(self, input):
num_filter_channels = num_channels self._num_channels = input.shape[1]
if self._groups is None:
num_filter_channels = self._num_channels
else: else:
if num_channels % groups != 0: if self._num_channels % self._groups != 0:
raise ValueError("num_channels must be divisible by groups.") raise ValueError("num_channels must be divisible by groups.")
num_filter_channels = num_channels // groups num_filter_channels = self._num_channels // self._groups
filter_size = utils.convert_to_list(filter_size, 2, 'filter_size') filter_size = utils.convert_to_list(self._filter_size, 2, 'filter_size')
filter_shape = [num_filters, int(num_filter_channels)] + filter_size filter_shape = [self._num_filters, int(num_filter_channels)
] + filter_size
def _get_default_param_initializer(): def _get_default_param_initializer():
filter_elem_num = filter_size[0] * filter_size[1] * num_channels filter_elem_num = filter_size[0] * filter_size[
1] * self._num_channels
std = (2.0 / filter_elem_num)**0.5 std = (2.0 / filter_elem_num)**0.5
return Normal(0.0, std, 0) return Normal(0.0, std, 0)
self._filter_param = self.create_parameter( self._filter_param = self.create_parameter(
attr=param_attr, attr=self._param_attr,
shape=filter_shape, shape=filter_shape,
dtype=self._dtype, dtype=self._dtype,
default_initializer=_get_default_param_initializer()) default_initializer=_get_default_param_initializer())
@ -204,8 +218,8 @@ class Conv2D(layers.Layer):
type=core.VarDesc.VarType.RAW) type=core.VarDesc.VarType.RAW)
self._bias_param = self.create_parameter( self._bias_param = self.create_parameter(
attr=bias_attr, attr=self._bias_attr,
shape=[num_filters], shape=[self._num_filters],
dtype=self._dtype, dtype=self._dtype,
is_bias=True) is_bias=True)
@ -653,14 +667,12 @@ class Conv3DTranspose(layers.Layer):
class Pool2D(layers.Layer): class Pool2D(layers.Layer):
# TODO, should delete this class
""" """
${comment} ${comment}
Args: Args:
input (Variable): The input tensor of pooling operator. The format of name_scope(str) : The name of this class.
input tensor is NCHW, where N is batch size, C is
the number of channels, H is the height of the
feature, and W is the width of the feature.
pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list, pool_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two integers, (pool_size_Height, pool_size_Width). it must contain two integers, (pool_size_Height, pool_size_Width).
Otherwise, the pool kernel size will be a square of an int. Otherwise, the pool kernel size will be a square of an int.
@ -814,8 +826,7 @@ class FC(layers.Layer):
out.shape = (1, 2) out.shape = (1, 2)
Args: Args:
input (Variable|list of Variable): The input tensor(s) of this layer, and the dimension of name(str): The name of this class.
the input tensor(s) is at least 2.
size(int): The number of output units in this layer. size(int): The number of output units in this layer.
num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than num_flatten_dims (int, default 1): The fc layer can accept an input tensor with more than
two dimensions. If this happens, the multidimensional tensor will first be flattened two dimensions. If this happens, the multidimensional tensor will first be flattened
@ -833,37 +844,35 @@ class FC(layers.Layer):
If it is set to None, the bias is initialized zero. Default: None. If it is set to None, the bias is initialized zero. Default: None.
act (str, default None): Activation to be applied to the output of this layer. act (str, default None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase. is_test(bool): A flag indicating whether execution is in test phase.
name (str, default None): The name of this layer. dtype(str): Dtype used for weight
Returns:
Variable: The transformation result.
Raises: Raises:
ValueError: If rank of the input tensor is less than 2. ValueError: If rank of the input tensor is less than 2.
Examples: Examples:
.. code-block:: python .. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform( -1, 1, [30, 10, 32] ).astype('float32')
with fluid.dygraph.guard():
fc = FC( "fc", 64, num_flatten_dims=2)
data = to_variable( data )
conv = fc( data )
# when input is single tensor
data = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
fc = fluid.FC("fc", size=1000, act="tanh")
fc_res = fc(data)
# when input are multiple tensors
data_1 = fluid.layers.data(name="data_1", shape=[32, 32], dtype="float32")
data_2 = fluid.layers.data(name="data_2", shape=[24, 36], dtype="float32")
fc = fluid.FC("fc", size=1000, act="tanh")
fc_res = fc([data_1, data_2])
""" """
def __init__(self, def __init__(self,
name_scope, name_scope,
size, size,
num_flatten_dims=1,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
num_flatten_dims=1, act=None,
dtype=core.VarDesc.VarType.FP32, is_test=False,
act=None): dtype="float32"):
super(FC, self).__init__(name_scope, dtype) super(FC, self).__init__(name_scope, dtype)
self._size = size self._size = size
@ -1048,7 +1057,7 @@ class BatchNorm(layers.Layer):
epsilon=1e-05, epsilon=1e-05,
param_attr=None, param_attr=None,
bias_attr=None, bias_attr=None,
dtype=core.VarDesc.VarType.FP32, dtype='float32',
data_layout='NCHW', data_layout='NCHW',
in_place=False, in_place=False,
moving_mean_name=None, moving_mean_name=None,
@ -1064,8 +1073,8 @@ class BatchNorm(layers.Layer):
assert bias_attr is not False, "bias_attr should not be False in batch_norm." assert bias_attr is not False, "bias_attr should not be False in batch_norm."
if dtype == core.VarDesc.VarType.FP16: if dtype == "float16":
self._dtype = core.VarDesc.VarType.FP32 self._dtype = "float32"
else: else:
self._dtype = dtype self._dtype = dtype
@ -1444,6 +1453,7 @@ class GRUUnit(layers.Layer):
Default: 'tanh' Default: 'tanh'
gate_activation (string): The activation type for gates (actGate). gate_activation (string): The activation type for gates (actGate).
Default: 'sigmoid' Default: 'sigmoid'
dtype(string): The dtype of the layers
Returns: Returns:
tuple: The hidden value, reset-hidden value and gate values. tuple: The hidden value, reset-hidden value and gate values.

@ -31,7 +31,7 @@ class SimpleLayer(Layer):
super(SimpleLayer, self).__init__(name_scope) super(SimpleLayer, self).__init__(name_scope)
self._fc1 = nn.FC(self.full_name(), self._fc1 = nn.FC(self.full_name(),
3, 3,
ParamAttr(initializer=Constant(value=0.1))) param_attr=ParamAttr(initializer=Constant(value=0.1)))
def forward(self, inputs): def forward(self, inputs):
x = self._fc1(inputs) x = self._fc1(inputs)

@ -55,7 +55,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
self._conv2d = Conv2D( self._conv2d = Conv2D(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=conv_stride, stride=conv_stride,

@ -47,7 +47,6 @@ class ConvBNLayer(fluid.dygraph.Layer):
self._conv = Conv2D( self._conv = Conv2D(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,

@ -51,7 +51,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
self._conv2d = Conv2D( self._conv2d = Conv2D(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=conv_stride, stride=conv_stride,

@ -25,7 +25,6 @@ from paddle.fluid.dygraph.base import to_variable
class SimpleImgConvPool(fluid.Layer): class SimpleImgConvPool(fluid.Layer):
def __init__(self, def __init__(self,
name_scope, name_scope,
num_channels,
num_filters, num_filters,
filter_size, filter_size,
pool_size, pool_size,
@ -45,7 +44,6 @@ class SimpleImgConvPool(fluid.Layer):
self._conv2d = Conv2D( self._conv2d = Conv2D(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=conv_stride, stride=conv_stride,
@ -76,10 +74,10 @@ class MNIST(fluid.Layer):
super(MNIST, self).__init__(name_scope) super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool( self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu") self.full_name(), 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool( self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu") self.full_name(), 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4 pool_2_shape = 50 * 4 * 4
SIZE = 10 SIZE = 10

@ -31,7 +31,6 @@ from test_imperative_base import new_program_scope
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self, def __init__(self,
name_scope, name_scope,
num_channels,
num_filters, num_filters,
filter_size, filter_size,
pool_size, pool_size,
@ -51,7 +50,6 @@ class SimpleImgConvPool(fluid.dygraph.Layer):
self._conv2d = Conv2D( self._conv2d = Conv2D(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=conv_stride, stride=conv_stride,
@ -82,10 +80,10 @@ class MNIST(fluid.dygraph.Layer):
super(MNIST, self).__init__(name_scope) super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool( self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu") self.full_name(), 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool( self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu") self.full_name(), 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4 pool_2_shape = 50 * 4 * 4
SIZE = 10 SIZE = 10

@ -80,7 +80,6 @@ class ConvBNPool(fluid.dygraph.Layer):
self.conv_0_layer = Conv2D( self.conv_0_layer = Conv2D(
self.full_name(), self.full_name(),
channels[0],
out_ch[0], out_ch[0],
3, 3,
padding=1, padding=1,
@ -92,7 +91,6 @@ class ConvBNPool(fluid.dygraph.Layer):
self.full_name(), out_ch[0], act=act, is_test=is_test) self.full_name(), out_ch[0], act=act, is_test=is_test)
self.conv_1_layer = Conv2D( self.conv_1_layer = Conv2D(
self.full_name(), self.full_name(),
num_channels=channels[1],
num_filters=out_ch[1], num_filters=out_ch[1],
filter_size=3, filter_size=3,
padding=1, padding=1,

@ -71,7 +71,6 @@ def optimizer_setting(params):
class ConvBNLayer(fluid.Layer): class ConvBNLayer(fluid.Layer):
def __init__(self, def __init__(self,
name_scope, name_scope,
num_channels,
num_filters, num_filters,
filter_size, filter_size,
stride=1, stride=1,
@ -81,7 +80,6 @@ class ConvBNLayer(fluid.Layer):
self._conv = Conv2D( self._conv = Conv2D(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
@ -100,30 +98,22 @@ class ConvBNLayer(fluid.Layer):
class BottleneckBlock(fluid.Layer): class BottleneckBlock(fluid.Layer):
def __init__(self, def __init__(self, name_scope, num_filters, stride, shortcut=True):
name_scope,
num_channels,
num_filters,
stride,
shortcut=True):
super(BottleneckBlock, self).__init__(name_scope) super(BottleneckBlock, self).__init__(name_scope)
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=1, filter_size=1,
act='relu') act='relu')
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=num_filters,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
stride=stride, stride=stride,
act='relu') act='relu')
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=num_filters,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
act=None) act=None)
@ -131,15 +121,12 @@ class BottleneckBlock(fluid.Layer):
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
stride=stride) stride=stride)
self.shortcut = shortcut self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs): def forward(self, inputs):
y = self.conv0(inputs) y = self.conv0(inputs)
conv1 = self.conv1(y) conv1 = self.conv1(y)
@ -175,7 +162,6 @@ class ResNet(fluid.Layer):
self.conv = ConvBNLayer( self.conv = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=3,
num_filters=64, num_filters=64,
filter_size=7, filter_size=7,
stride=2, stride=2,
@ -188,7 +174,6 @@ class ResNet(fluid.Layer):
pool_type='max') pool_type='max')
self.bottleneck_block_list = [] self.bottleneck_block_list = []
num_channels = 64
for block in range(len(depth)): for block in range(len(depth)):
shortcut = False shortcut = False
for i in range(depth[block]): for i in range(depth[block]):
@ -196,11 +181,9 @@ class ResNet(fluid.Layer):
'bb_%d_%d' % (block, i), 'bb_%d_%d' % (block, i),
BottleneckBlock( BottleneckBlock(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters[block], num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1, stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut)) shortcut=shortcut))
num_channels = bottleneck_block._num_channels_out
self.bottleneck_block_list.append(bottleneck_block) self.bottleneck_block_list.append(bottleneck_block)
shortcut = True shortcut = True

@ -64,7 +64,6 @@ def optimizer_setting(params):
class ConvBNLayer(fluid.dygraph.Layer): class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self, def __init__(self,
name_scope, name_scope,
num_channels,
num_filters, num_filters,
filter_size, filter_size,
stride=1, stride=1,
@ -74,7 +73,6 @@ class ConvBNLayer(fluid.dygraph.Layer):
self._conv = Conv2D( self._conv = Conv2D(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=filter_size, filter_size=filter_size,
stride=stride, stride=stride,
@ -131,20 +129,15 @@ class BottleneckBlock(fluid.dygraph.Layer):
super(BottleneckBlock, self).__init__(name_scope) super(BottleneckBlock, self).__init__(name_scope)
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
self.full_name(), self.full_name(), num_filters=num_filters, filter_size=1)
num_channels=num_channels,
num_filters=num_filters,
filter_size=1)
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=num_filters,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
stride=stride, stride=stride,
groups=cardinality) groups=cardinality)
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=num_filters,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
act='relu') act='relu')
@ -157,7 +150,6 @@ class BottleneckBlock(fluid.dygraph.Layer):
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=num_channels,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
stride=stride) stride=stride)
@ -200,7 +192,6 @@ class SeResNeXt(fluid.dygraph.Layer):
num_filters = [128, 256, 512, 1024] num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=3,
num_filters=64, num_filters=64,
filter_size=7, filter_size=7,
stride=2, stride=2,
@ -218,7 +209,6 @@ class SeResNeXt(fluid.dygraph.Layer):
num_filters = [128, 256, 512, 1024] num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=3,
num_filters=3, num_filters=3,
filter_size=7, filter_size=7,
stride=2, stride=2,
@ -236,21 +226,18 @@ class SeResNeXt(fluid.dygraph.Layer):
num_filters = [128, 256, 512, 1024] num_filters = [128, 256, 512, 1024]
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=3,
num_filters=3, num_filters=3,
filter_size=7, filter_size=7,
stride=2, stride=2,
act='relu') act='relu')
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=64,
num_filters=3, num_filters=3,
filter_size=7, filter_size=7,
stride=2, stride=2,
act='relu') act='relu')
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
self.full_name(), self.full_name(),
num_channels=64,
num_filters=3, num_filters=3,
filter_size=7, filter_size=7,
stride=2, stride=2,

@ -20,3 +20,7 @@ import paddle.fluid as fluid
class TestInstallCheck(unittest.TestCase): class TestInstallCheck(unittest.TestCase):
def test_install_check(self): def test_install_check(self):
fluid.install_check.run_check() fluid.install_check.run_check()
if __name__ == '__main__':
unittest.main()

@ -190,8 +190,7 @@ class TestLayer(LayerTest):
with self.static_graph(): with self.static_graph():
images = layers.data(name='pixel', shape=[3, 5, 5], dtype='float32') images = layers.data(name='pixel', shape=[3, 5, 5], dtype='float32')
conv2d = nn.Conv2D( conv2d = nn.Conv2D('conv2d', num_filters=3, filter_size=[2, 2])
'conv2d', num_channels=3, num_filters=3, filter_size=[2, 2])
ret = conv2d(images) ret = conv2d(images)
static_ret2 = self.get_static_graph_result( static_ret2 = self.get_static_graph_result(
feed={'pixel': np.ones( feed={'pixel': np.ones(
@ -200,8 +199,7 @@ class TestLayer(LayerTest):
with self.dynamic_graph(): with self.dynamic_graph():
images = np.ones([2, 3, 5, 5], dtype='float32') images = np.ones([2, 3, 5, 5], dtype='float32')
conv2d = nn.Conv2D( conv2d = nn.Conv2D('conv2d', num_filters=3, filter_size=[2, 2])
'conv2d', num_channels=3, num_filters=3, filter_size=[2, 2])
dy_ret = conv2d(base.to_variable(images)) dy_ret = conv2d(base.to_variable(images))
self.assertTrue(np.allclose(static_ret, dy_ret.numpy())) self.assertTrue(np.allclose(static_ret, dy_ret.numpy()))

Loading…
Cancel
Save