You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
261 lines
9.0 KiB
261 lines
9.0 KiB
9 years ago
|
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
from paddle.trainer_config_helpers import *
|
||
|
|
||
|
"""
paper: https://arxiv.org/abs/1512.03385
"""

# Run-time switches read through get_config_arg (defaults shown).
is_test = get_config_arg("is_test", bool, False)
is_predict = get_config_arg("is_predict", bool, False)
data_provider = get_config_arg("data_provider", bool, True)
layer_num = get_config_arg("layer_num", int, 50)

if not is_predict and data_provider:
    # In test mode no training list is registered; only the test list is used.
    train_list = 'train.list' if not is_test else None
    # mean.meta is the mean file of the ImageNet dataset.
    # mean.meta size: 3 x 224 x 224.
    # If you use three mean values instead, set something like:
    # "mean_value:103.939,116.779,123.68;"
    args = {
        'mean_meta': "model/mean_meta_224/mean.meta",
        'image_size': 224,
        'crop_size': 224,
        'color': True,
        # The key used to be spelled 'swap_channel:' (stray colon inside the
        # string), so a provider looking up 'swap_channel' never received it.
        # NOTE(review): presumably reorders channels (e.g. RGB -> BGR) to match
        # the pretrained model; confirm against example/image_list_provider.
        'swap_channel': [2, 1, 0],
    }
    define_py_data_sources2(train_list,
                            'example/test.list',
                            module="example.image_list_provider",
                            obj="processData",
                            args=args)

batch_size = 1
# NOTE(review): the learning rate is divided by batch_size and the weight
# decay multiplied by it, presumably because gradients are summed rather than
# averaged over a batch in this trainer; confirm before changing batch_size.
learning_rate = 0.1 / batch_size
momentum = 0.9
weight_decay = 0.0001 * batch_size
default_momentum(momentum)
default_decay_rate(weight_decay)

Settings(
    algorithm='sgd',
    batch_size=batch_size,
    learning_rate=learning_rate,

    # Set the appropriate parameters according to your schedule.
    learning_method='momentum',
    learning_rate_decay_a=0.5,
    learning_rate_decay_b=1200000 * 10,
    learning_rate_schedule="discexp",
)
|
||
|
|
||
|
|
||
|
def conv_bn_layer(name, input, filter_size, num_filters,
                  stride, padding, channels=None,
                  active_type=ReluActivation()):
    """
    Build a bias-free convolution followed by batch normalization.

    The convolution itself is linear (LinearActivation, no bias); the
    non-linearity ``active_type`` is applied by the batch-norm layer.

    Args:
        name: prefix of the generated layers ("<name>_conv", "<name>_bn").
        input: input layer.
        filter_size: convolution kernel size.
        num_filters: number of convolution output channels.
        stride: convolution stride.
        padding: convolution padding.
        channels: number of input channels; None is passed through to
            img_conv_layer unchanged.
        active_type: activation applied after batch normalization. The
            default instance is created once, when this function is defined.

    Returns:
        The batch-normalization layer. It uses global statistics when the
        module-level ``is_test`` flag is set.
    """
    conv_out = img_conv_layer(input=input,
                              name=name + "_conv",
                              num_channels=channels,
                              num_filters=num_filters,
                              filter_size=filter_size,
                              stride=stride,
                              padding=padding,
                              bias_attr=False,
                              act=LinearActivation())

    normalized = batch_norm_layer(input=conv_out,
                                  name=name + "_bn",
                                  use_global_stats=is_test,
                                  act=active_type)
    return normalized
|
||
|
|
||
|
|
||
|
def bottleneck_block(name, input, num_filters1, num_filters2):
    """
    Build a ResNet bottleneck building block with an identity shortcut.

    Residual branch: 1x1 conv (num_filters1) -> 3x3 conv (num_filters1)
    -> 1x1 conv (num_filters2), each followed by batch normalization. The
    first two use conv_bn_layer's default activation; the last is linear.
    The branch is summed with ``input`` and a ReLU is applied to the sum.

    NOTE(review): the identity shortcut requires ``input`` to already have
    num_filters2 channels and the same spatial size as the branch output.

    Returns:
        The "<name>_addto" layer.
    """
    reduced = conv_bn_layer(name=name + '_branch2a',
                            input=input,
                            filter_size=1,
                            num_filters=num_filters1,
                            stride=1,
                            padding=0)
    spatial = conv_bn_layer(name=name + '_branch2b',
                            input=reduced,
                            filter_size=3,
                            num_filters=num_filters1,
                            stride=1,
                            padding=1)
    expanded = conv_bn_layer(name=name + '_branch2c',
                             input=spatial,
                             filter_size=1,
                             num_filters=num_filters2,
                             stride=1,
                             padding=0,
                             active_type=LinearActivation())

    shortcut = input
    return addto_layer(name=name + "_addto",
                       input=[shortcut, expanded],
                       act=ReluActivation())
|
||
|
|
||
|
|
||
|
def mid_projection(name, input, num_filters1, num_filters2, stride=2):
    """
    Build a ResNet bottleneck block whose shortcut is a projection.

    Used where the number of channels (and, with stride > 1, the spatial
    resolution) changes, so an identity shortcut cannot be used.

    branch1:  1x1 conv (num_filters2, ``stride``) + batch norm, linear;
              this is the projection shortcut.
    branch2x: 1x1 conv (num_filters1, ``stride``) -> 3x3 conv
              (num_filters1) -> 1x1 conv (num_filters2, linear), each with
              batch normalization.
    The two branches are summed and passed through a ReLU.

    Args:
        name: prefix of all generated layer names.
        input: input layer.
        num_filters1: channels of the inner 1x1 and 3x3 convolutions.
        num_filters2: output channels of the block.
        stride: stride applied to both branch1 and branch2a.

    Returns:
        The "<name>_addto" layer.
    """
    # Projection shortcut (built first, as in the original layer order).
    shortcut = conv_bn_layer(name=name + '_branch1',
                             input=input,
                             filter_size=1,
                             num_filters=num_filters2,
                             stride=stride,
                             padding=0,
                             active_type=LinearActivation())

    # Residual path; the stride sits on the first 1x1 convolution.
    reduced = conv_bn_layer(name=name + '_branch2a',
                            input=input,
                            filter_size=1,
                            num_filters=num_filters1,
                            stride=stride,
                            padding=0)
    spatial = conv_bn_layer(name=name + '_branch2b',
                            input=reduced,
                            filter_size=3,
                            num_filters=num_filters1,
                            stride=1,
                            padding=1)
    expanded = conv_bn_layer(name=name + '_branch2c',
                             input=spatial,
                             filter_size=1,
                             num_filters=num_filters2,
                             stride=1,
                             padding=0,
                             active_type=LinearActivation())

    return addto_layer(name=name + "_addto",
                       input=[shortcut, expanded],
                       act=ReluActivation())
|
||
|
|
||
|
|
||
|
def _res_stage(prefix, input, num_blocks, num_filters1, num_filters2, stride):
    """
    Build one ResNet stage: a projection block plus identity bottlenecks.

    The first block is named "<prefix>_1" (mid_projection). Blocks
    "<prefix>_2" .. "<prefix>_<num_blocks>" are bottleneck_blocks.

    Returns:
        The last layer of the stage.
    """
    tmp = mid_projection(name=prefix + "_1",
                         input=input,
                         num_filters1=num_filters1,
                         num_filters2=num_filters2,
                         stride=stride)
    # `range` (not Python 2-only `xrange`) keeps the config Python 3 compatible.
    for i in range(2, num_blocks + 1):
        tmp = bottleneck_block(name=prefix + "_" + str(i),
                               input=tmp,
                               num_filters1=num_filters1,
                               num_filters2=num_filters2)
    return tmp


def deep_res_net(res2_num=3, res3_num=4, res4_num=6, res5_num=3):
    """
    Build the 50/101/152-layer ResNet topology for 224x224x3 inputs.

    Args:
        res2_num: number of blocks stacked in conv2_x.
        res3_num: number of blocks stacked in conv3_x.
        res4_num: number of blocks stacked in conv4_x.
        res5_num: number of blocks stacked in conv5_x.

    The network ends in a 1000-way softmax fc layer named 'output'. When
    the module-level ``is_predict`` flag is False, a classification cost
    against a 'label' data layer is added. Returns None; the layers are
    registered with the config as a side effect.
    """
    # For ImageNet
    # conv1: 112x112
    img = data_layer(name='input', size=224 * 224 * 3)
    tmp = conv_bn_layer("conv1", img,
                        filter_size=7,
                        channels=3,
                        num_filters=64,
                        stride=2,
                        padding=3)
    tmp = img_pool_layer(name="pool1", input=tmp, pool_size=3, stride=2)

    # conv2_x: 56x56 (stride 1: pool1 has already down-sampled)
    tmp = _res_stage("res2", tmp, res2_num, 64, 256, stride=1)
    # conv3_x: 28x28
    tmp = _res_stage("res3", tmp, res3_num, 128, 512, stride=2)
    # conv4_x: 14x14
    tmp = _res_stage("res4", tmp, res4_num, 256, 1024, stride=2)
    # conv5_x: 7x7
    tmp = _res_stage("res5", tmp, res5_num, 512, 2048, stride=2)

    tmp = img_pool_layer(name='avgpool',
                         input=tmp,
                         pool_size=7,
                         stride=1,
                         pool_type=AvgPooling())

    output = fc_layer(name='output',
                      input=tmp,
                      size=1000,
                      act=SoftmaxActivation())

    if not is_predict:
        classification_cost(input=output,
                            label=data_layer(name='label', size=1))
|
||
|
|
||
|
|
||
|
def res_net_50():
    """Build ResNet-50: 3, 4, 6 and 3 blocks in conv2_x .. conv5_x."""
    deep_res_net(res2_num=3, res3_num=4, res4_num=6, res5_num=3)
|
||
|
|
||
|
|
||
|
def res_net_101():
    """Build ResNet-101: 3, 4, 23 and 3 blocks in conv2_x .. conv5_x."""
    deep_res_net(res2_num=3, res3_num=4, res4_num=23, res5_num=3)
|
||
|
|
||
|
|
||
|
def res_net_152():
    """Build ResNet-152: 3, 8, 36 and 3 blocks in conv2_x .. conv5_x."""
    deep_res_net(res2_num=3, res3_num=8, res4_num=36, res5_num=3)
|
||
|
|
||
|
|
||
|
if not is_predict:
    Inputs("input", "label")
else:
    Inputs("input")
# Outputs("cost-softmax" if not is_predict else "output")
# NOTE(review): the outputs are fixed to the res5_3 branch2c feature layers
# regardless of is_predict (presumably a feature-extraction setup). These
# names exist for every supported depth because all of them use res5_num=3.
# Use the commented line above to expose the cost / softmax output instead.
Outputs("res5_3_branch2c_conv", "res5_3_branch2c_bn")

if layer_num == 50:
    res_net_50()
elif layer_num == 101:
    res_net_101()
elif layer_num == 152:
    res_net_152()
else:
    # Fail loudly: previously only a message was printed and config parsing
    # continued with no network defined, surfacing later as an unrelated error.
    raise ValueError("Wrong layer number: %s (expected 50, 101 or 152)."
                     % layer_num)
|