add se block for resnet50

pull/4362/head
qujianwei 5 years ago
parent 0ae5eeb33d
commit b079e34e73

@ -38,17 +38,20 @@ de.config.set_seed(1)
if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet
if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
else:
elif args_opt.net == "resnet101":
from src.resnet import resnet101 as resnet
from src.config import config3 as config
from src.dataset import create_dataset3 as create_dataset
else:
from src.resnet import se_resnet50 as resnet
from src.config import config4 as config
from src.dataset import create_dataset4 as create_dataset
if __name__ == '__main__':
target = args_opt.device_target

@ -16,13 +16,13 @@
if [ $# != 4 ] && [ $# != 5 ]
then
echo "Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: sh run_distribute_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101"
echo "error: the selected net is neither resnet50 nor resnet101 and se-resnet50"
exit 1
fi
@ -38,6 +38,11 @@ then
exit 1
fi
if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
then
echo "error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then

@ -16,13 +16,13 @@
if [ $# != 4 ]
then
echo "Usage: sh run_eval.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
echo "Usage: sh run_eval.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101"
echo "error: the selected net is neither resnet50 nor resnet101 nor se-resnet50"
exit 1
fi
@ -38,6 +38,11 @@ then
exit 1
fi
if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
then
echo "error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then

@ -16,13 +16,13 @@
if [ $# != 3 ] && [ $# != 4 ]
then
echo "Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo "Usage: sh run_standalone_train.sh [resnet50|resnet101|se-resnet50] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit 1
fi
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ]
if [ $1 != "resnet50" ] && [ $1 != "resnet101" ] && [ $1 != "se-resnet50" ]
then
echo "error: the selected net is neither resnet50 nor resnet101"
echo "error: the selected net is neither resnet50 nor resnet101 and se-resnet50"
exit 1
fi
@ -38,6 +38,11 @@ then
exit 1
fi
if [ $1 == "se-resnet50" ] && [ $2 == "cifar10" ]
then
echo "error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then

@ -50,12 +50,12 @@ config2 = ed({
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 0,
"lr_decay_mode": "cosine",
"lr_decay_mode": "linear",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0,
"lr_max": 0.1
"lr_max": 0.1,
"lr_end": 0.0
})
# config for resent101, imagenet2012
@ -77,3 +77,25 @@ config3 = ed({
"label_smooth_factor": 0.1,
"lr": 0.1
})
# config for se-resnet50, imagenet2012
config4 = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 1024,
"momentum": 0.9,
"weight_decay": 1e-4,
"epoch_size": 28,
"pretrain_epoch_size": 1,
"save_checkpoint": True,
"save_checkpoint_epochs": 4,
"keep_checkpoint_max": 10,
"save_checkpoint_path": "./",
"warmup_epochs": 3,
"lr_decay_mode": "cosine",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.0,
"lr_max": 0.3,
"lr_end": 0.0001
})

@ -22,7 +22,6 @@ import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from mindspore.communication.management import init, get_rank, get_group_size
def create_dataset1(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or evaluate cifar10 dataset for resnet50
@ -191,6 +190,59 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target=
return ds
def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
"""
create a train or eval imagenet2012 dataset for se-resnet50
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1
batch_size(int): the batch size of dataset. Default: 32
target(str): the device target. Default: Ascend
Returns:
dataset
"""
if target == "Ascend":
device_num, rank_id = _get_rank_info()
if device_num == 1:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=12, shuffle=True)
else:
ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=12, shuffle=True,
num_shards=device_num, shard_id=rank_id)
image_size = 224
mean = [123.68, 116.78, 103.94]
std = [1.0, 1.0, 1.0]
# define map operations
if do_train:
trans = [
C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
else:
trans = [
C.Decode(),
C.Resize(292),
C.CenterCrop(256),
C.Normalize(mean=mean, std=std),
C.HWC2CHW()
]
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="image", num_parallel_workers=12, operations=trans)
ds = ds.map(input_columns="label", num_parallel_workers=12, operations=type_cast_op)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
# apply dataset repeat operation
ds = ds.repeat(repeat_num)
return ds
def _get_rank_info():
"""

@ -62,6 +62,18 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
elif lr_decay_mode == 'cosine':
decay_steps = total_steps - warmup_steps
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:

File diff suppressed because it is too large Load Diff

@ -50,17 +50,21 @@ de.config.set_seed(1)
if args_opt.net == "resnet50":
from src.resnet import resnet50 as resnet
if args_opt.dataset == "cifar10":
from src.config import config1 as config
from src.dataset import create_dataset1 as create_dataset
else:
from src.config import config2 as config
from src.dataset import create_dataset2 as create_dataset
else:
elif args_opt.net == "resnet101":
from src.resnet import resnet101 as resnet
from src.config import config3 as config
from src.dataset import create_dataset3 as create_dataset
else:
from src.resnet import se_resnet50 as resnet
from src.config import config4 as config
from src.dataset import create_dataset4 as create_dataset
if __name__ == '__main__':
target = args_opt.device_target
@ -74,7 +78,7 @@ if __name__ == '__main__':
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
if args_opt.net == "resnet50":
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
@ -112,14 +116,10 @@ if __name__ == '__main__':
cell.weight.dtype)
# init lr
if args_opt.net == "resnet50":
if args_opt.dataset == "cifar10":
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode='poly')
else:
lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine')
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
else:
lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size,
config.pretrain_epoch_size * step_size)

Loading…
Cancel
Save