|
|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
# ============================================================================
|
|
|
|
|
"""VGG."""
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore.ops import operations as P
|
|
|
|
|
from mindspore.common.initializer import initializer
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
|
|
|
|
|
@ -63,8 +62,7 @@ class Vgg(nn.Cell):
|
|
|
|
|
def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1):
|
|
|
|
|
super(Vgg, self).__init__()
|
|
|
|
|
self.layers = _make_layer(base, batch_norm=batch_norm)
|
|
|
|
|
self.reshape = P.Reshape()
|
|
|
|
|
self.shp = (batch_size, -1)
|
|
|
|
|
self.flatten = nn.Flatten()
|
|
|
|
|
self.classifier = nn.SequentialCell([
|
|
|
|
|
nn.Dense(512 * 7 * 7, 4096),
|
|
|
|
|
nn.ReLU(),
|
|
|
|
|
@ -74,7 +72,7 @@ class Vgg(nn.Cell):
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
x = self.layers(x)
|
|
|
|
|
x = self.reshape(x, self.shp)
|
|
|
|
|
x = self.flatten(x)
|
|
|
|
|
x = self.classifier(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
@ -87,20 +85,19 @@ cfg = {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def vgg16(batch_size=1, num_classes=1000):
|
|
|
|
|
def vgg16(num_classes=1000):
|
|
|
|
|
"""
|
|
|
|
|
Get Vgg16 neural network with batch normalization.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
batch_size (int): Batch size. Default: 1.
|
|
|
|
|
num_classes (int): Class numbers. Default: 1000.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Cell, cell instance of Vgg16 neural network with batch normalization.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> vgg16(batch_size=1, num_classes=1000)
|
|
|
|
|
>>> vgg16(num_classes=1000)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True, batch_size=batch_size)
|
|
|
|
|
net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True)
|
|
|
|
|
return net
|
|
|
|
|
|