Add with_pool args for vgg (#28684)

* add arg for vgg
musl/fix_failed_unittests_in_musl
LielinJiang 4 years ago committed by GitHub
parent 532e4bbf2a
commit 01a14e1be2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -245,7 +245,7 @@ class ResNet(nn.Layer):
x = self.layer3(x)
x = self.layer4(x)
if self.with_pool > 0:
if self.with_pool:
x = self.avgpool(x)
if self.num_classes > 0:

@ -36,9 +36,10 @@ class VGG(nn.Layer):
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
Args:
features (nn.Layer): vgg features create by function make_layers.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
features (nn.Layer): Vgg features create by function make_layers.
num_classes (int): Output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
with_pool (bool): Use pool before the last three fc layer or not. Default: True.
Examples:
.. code-block:: python
@ -54,24 +55,35 @@ class VGG(nn.Layer):
"""
def __init__(self, features, num_classes=1000):
def __init__(self, features, num_classes=1000, with_pool=True):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2D((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(),
nn.Linear(4096, num_classes), )
self.num_classes = num_classes
self.with_pool = with_pool
if with_pool:
self.avgpool = nn.AdaptiveAvgPool2D((7, 7))
if num_classes > 0:
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(),
nn.Linear(4096, num_classes), )
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = paddle.flatten(x, 1)
x = self.classifier(x)
if self.with_pool:
x = self.avgpool(x)
if self.num_classes > 0:
x = paddle.flatten(x, 1)
x = self.classifier(x)
return x

Loading…
Cancel
Save