You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
231 lines
6.2 KiB
231 lines
6.2 KiB
#!/usr/bin/env python
|
|
from paddle.trainer_config_helpers import *
|
|
|
|
# Input image geometry and the number of target classes.
height = width = 224
num_class = 1000

# Values read through get_config_arg(name, type, default). The third argument
# looks like the fallback used when no override is supplied -- TODO confirm
# against the trainer's config-arg handling.
batch_size = get_config_arg("batch_size", int, 64)
layer_num = get_config_arg("layer_num", int, 50)
is_infer = get_config_arg("is_infer", bool, False)
num_samples = get_config_arg("num_samples", int, 2560)
|
|
|
|
# Keyword arguments forwarded to the data provider through `args=`.
args = dict(
    height=height,
    width=width,
    color=True,
    num_class=num_class,
    is_infer=is_infer,
    num_samples=num_samples)

# Exactly one of the two list files is passed: "train.list" when training,
# "test.list" when running inference; the other slot is None.
define_py_data_sources2(
    None if is_infer else "train.list",
    "test.list" if is_infer else None,
    module="provider",
    obj="process",
    args=args)
|
|
|
|
# Optimizer configuration. The learning rate is divided by the batch size and
# the L2 coefficient is multiplied by it -- presumably to compensate for how
# the trainer aggregates gradients over a batch; TODO confirm.
settings(
    batch_size=batch_size,
    learning_rate=1e-2 / batch_size,
    learning_method=MomentumOptimizer(0.9),
    regularization=L2Regularization(5e-4 * batch_size))
|
|
|
|
|
|
#######################Network Configuration #############
|
|
def conv_bn_layer(name,
                  input,
                  filter_size,
                  num_filters,
                  stride,
                  padding,
                  channels=None,
                  active_type=ReluActivation()):
    """
    Build a convolution layer followed by a batch normalization layer.

    The convolution is created with a linear activation and without a bias
    term; the activation passed as ``active_type`` is applied by the batch
    normalization layer instead.

    name: prefix for the two layers, which are named ``name + "_conv"`` and
        ``name + "_bn"``.
    input: layer fed into the convolution.
    filter_size, num_filters, stride, padding: forwarded unchanged to
        img_conv_layer.
    channels: forwarded to img_conv_layer as ``num_channels``.
    active_type: activation used by the batch normalization layer.

    Returns the batch normalization layer. Its ``use_global_stats`` option
    is taken from the module-level ``is_infer`` flag.
    """
    conv = img_conv_layer(
        name=name + "_conv",
        input=input,
        filter_size=filter_size,
        num_channels=channels,
        num_filters=num_filters,
        stride=stride,
        padding=padding,
        act=LinearActivation(),
        bias_attr=False)

    bn = batch_norm_layer(
        name=name + "_bn",
        input=conv,
        act=active_type,
        use_global_stats=is_infer)
    return bn
|
|
|
|
|
|
def bottleneck_block(name, input, num_filters1, num_filters2):
    """
    Build a ResNet bottleneck block with an identity shortcut.

    Three conv_bn_layer calls are chained: 1x1 and 3x3 convolutions with
    ``num_filters1`` filters, then a 1x1 convolution with ``num_filters2``
    filters. All use stride 1. The last one has a linear activation; the
    ReLU is applied by the addto layer that sums ``input`` with the branch
    output.

    Layers are named ``name + "_branch2a"``, ``"_branch2b"``, ``"_branch2c"``
    and ``"_addto"``. Returns the addto layer.
    """
    branch2a = conv_bn_layer(
        name=name + '_branch2a',
        input=input,
        filter_size=1,
        num_filters=num_filters1,
        stride=1,
        padding=0)

    branch2b = conv_bn_layer(
        name=name + '_branch2b',
        input=branch2a,
        filter_size=3,
        num_filters=num_filters1,
        stride=1,
        padding=1)

    # No activation here: it is applied after the shortcut is added.
    branch2c = conv_bn_layer(
        name=name + '_branch2c',
        input=branch2b,
        filter_size=1,
        num_filters=num_filters2,
        stride=1,
        padding=0,
        active_type=LinearActivation())

    shortcut_sum = addto_layer(
        name=name + "_addto", input=[input, branch2c], act=ReluActivation())
    return shortcut_sum
|
|
|
|
|
|
def mid_projection(name, input, num_filters1, num_filters2, stride=2):
    """
    Build a ResNet bottleneck block with a projection shortcut.

    branch1: a 1x1 conv_bn_layer with ``num_filters2`` filters, the given
        ``stride`` and a linear activation, used as the shortcut.
    branch2a/2b/2c: 1x1 (with ``stride``), 3x3 and 1x1 conv_bn_layer calls
        with ``num_filters1``, ``num_filters1`` and ``num_filters2`` filters;
        the last one has a linear activation.

    The two branches are summed by an addto layer with ReLU activation, named
    ``name + "_addto"``, which is returned.
    """
    # Shortcut path: projects the input to num_filters2 channels.
    shortcut = conv_bn_layer(
        name=name + '_branch1',
        input=input,
        filter_size=1,
        num_filters=num_filters2,
        stride=stride,
        padding=0,
        active_type=LinearActivation())

    # Main path: the first 1x1 convolution carries the same stride as the
    # shortcut; the remaining two use stride 1.
    branch2a = conv_bn_layer(
        name=name + '_branch2a',
        input=input,
        filter_size=1,
        num_filters=num_filters1,
        stride=stride,
        padding=0)

    branch2b = conv_bn_layer(
        name=name + '_branch2b',
        input=branch2a,
        filter_size=3,
        num_filters=num_filters1,
        stride=1,
        padding=1)

    branch2c = conv_bn_layer(
        name=name + '_branch2c',
        input=branch2b,
        filter_size=1,
        num_filters=num_filters2,
        stride=1,
        padding=0,
        active_type=LinearActivation())

    merged = addto_layer(
        name=name + "_addto", input=[shortcut, branch2c], act=ReluActivation())
    return merged
|
|
|
|
|
|
# Network input layer, sized for 3 values per pixel of a height x width image.
img = data_layer(name="image", size=3 * height * width)
|
|
|
|
|
|
def deep_res_net(res2_num=3, res3_num=4, res4_num=6, res5_num=3):
    """
    A wrapper for 50,101,152 layers of ResNet.

    The network reads the module-level ``img`` layer and consists of a 7x7
    stride-2 conv_bn_layer ("conv1"), a 3x3 stride-2 pooling layer ("pool1"),
    four stages of bottleneck blocks, a 7x7 average pooling layer ("avgpool")
    and a fully connected softmax layer of size ``num_class``.

    res2_num: number of blocks stacked in conv2_x
    res3_num: number of blocks stacked in conv3_x
    res4_num: number of blocks stacked in conv4_x
    res5_num: number of blocks stacked in conv5_x

    Each count includes the leading projection block: a stage with count n is
    one mid_projection named "<stage>_1" followed by bottleneck_block layers
    named "<stage>_2" .. "<stage>_n".

    Returns the final fully connected layer.
    """
    # For ImageNet
    # conv1: 112x112
    tmp = conv_bn_layer(
        "conv1",
        input=img,
        filter_size=7,
        channels=3,
        num_filters=64,
        stride=2,
        padding=3)
    tmp = img_pool_layer(name="pool1", input=tmp, pool_size=3, stride=2)

    # (name prefix, block count, num_filters1, num_filters2, projection stride)
    # Only the first stage keeps stride 1; the others use mid_projection's
    # default stride of 2, spelled out here.
    stages = (
        ("res2", res2_num, 64, 256, 1),  # conv2_x: 56x56
        ("res3", res3_num, 128, 512, 2),  # conv3_x: 28x28
        ("res4", res4_num, 256, 1024, 2),  # conv4_x: 14x14
        ("res5", res5_num, 512, 2048, 2),  # conv5_x: 7x7
    )
    for prefix, block_count, filters1, filters2, proj_stride in stages:
        tmp = mid_projection(
            name=prefix + "_1",
            input=tmp,
            num_filters1=filters1,
            num_filters2=filters2,
            stride=proj_stride)
        # range rather than the Python 2-only xrange: the counts are tiny and
        # range is available on both Python 2 and Python 3.
        for i in range(2, block_count + 1):
            tmp = bottleneck_block(
                name=prefix + "_" + str(i),
                input=tmp,
                num_filters1=filters1,
                num_filters2=filters2)

    tmp = img_pool_layer(
        name='avgpool',
        input=tmp,
        pool_size=7,
        stride=1,
        pool_type=AvgPooling())

    return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())
|
|
|
|
|
|
# Build the network with the per-stage block counts (conv2_x .. conv5_x) that
# correspond to the requested depth.
if layer_num == 50:
    resnet = deep_res_net(3, 4, 6, 3)
elif layer_num == 101:
    resnet = deep_res_net(3, 4, 23, 3)
elif layer_num == 152:
    resnet = deep_res_net(3, 8, 36, 3)
else:
    # Fail here with a clear message: only printing would leave `resnet`
    # undefined and surface later as an unrelated NameError.
    raise ValueError(
        "Wrong layer number: %r (expected 50, 101 or 152)." % (layer_num, ))
|
|
|
|
if not is_infer:
    # Training: attach a label input and expose the cross-entropy loss.
    lbl = data_layer(name="label", size=num_class)
    loss = cross_entropy(name='loss', input=resnet, label=lbl)
    outputs(loss)
else:
    # Inference: expose the network's softmax layer directly.
    outputs(resnet)
|