|
|
|
@ -50,17 +50,21 @@ de.config.set_seed(1)
|
|
|
|
|
|
|
|
|
|
if args_opt.net == "resnet50":
|
|
|
|
|
from src.resnet import resnet50 as resnet
|
|
|
|
|
|
|
|
|
|
if args_opt.dataset == "cifar10":
|
|
|
|
|
from src.config import config1 as config
|
|
|
|
|
from src.dataset import create_dataset1 as create_dataset
|
|
|
|
|
else:
|
|
|
|
|
from src.config import config2 as config
|
|
|
|
|
from src.dataset import create_dataset2 as create_dataset
|
|
|
|
|
else:
|
|
|
|
|
elif args_opt.net == "resnet101":
|
|
|
|
|
from src.resnet import resnet101 as resnet
|
|
|
|
|
from src.config import config3 as config
|
|
|
|
|
from src.dataset import create_dataset3 as create_dataset
|
|
|
|
|
else:
|
|
|
|
|
from src.resnet import se_resnet50 as resnet
|
|
|
|
|
from src.config import config4 as config
|
|
|
|
|
from src.dataset import create_dataset4 as create_dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
target = args_opt.device_target
|
|
|
|
@ -74,7 +78,7 @@ if __name__ == '__main__':
|
|
|
|
|
context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
|
|
|
|
|
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
mirror_mean=True)
|
|
|
|
|
if args_opt.net == "resnet50":
|
|
|
|
|
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
|
|
|
|
|
else:
|
|
|
|
|
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
|
|
|
|
@ -112,14 +116,10 @@ if __name__ == '__main__':
|
|
|
|
|
cell.weight.dtype)
|
|
|
|
|
|
|
|
|
|
# init lr
|
|
|
|
|
if args_opt.net == "resnet50":
|
|
|
|
|
if args_opt.dataset == "cifar10":
|
|
|
|
|
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
|
|
|
|
|
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
|
|
|
|
|
lr_decay_mode='poly')
|
|
|
|
|
else:
|
|
|
|
|
lr = get_lr(lr_init=config.lr_init, lr_end=0.0, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
|
|
|
|
|
total_epochs=config.epoch_size, steps_per_epoch=step_size, lr_decay_mode='cosine')
|
|
|
|
|
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
|
|
|
|
|
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
|
|
|
|
|
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
|
|
|
|
|
lr_decay_mode=config.lr_decay_mode)
|
|
|
|
|
else:
|
|
|
|
|
lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size,
|
|
|
|
|
config.pretrain_epoch_size * step_size)
|
|
|
|
|