Fix vgg error when num_classes is given (#28557)

* fix vgg num classes
musl/fix_failed_unittests_in_musl
LielinJiang 5 years ago committed by GitHub
parent 1de3cdd0ab
commit 1c3eef4cee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -71,6 +71,9 @@ class TestVisonModels(unittest.TestCase):
def test_resnet152(self):
self.models_infer('resnet152')
def test_vgg16_num_classes(self):
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)
def test_lenet(self):
input = InputSpec([None, 1, 28, 28], 'float32', 'x')
lenet = paddle.Model(models.__dict__['LeNet'](), input)

@ -107,10 +107,7 @@ cfgs = {
def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
model = VGG(make_layers(
cfgs[cfg], batch_norm=batch_norm),
num_classes=1000,
**kwargs)
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained:
assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(

Loading…
Cancel
Save