|
|
|
@ -98,24 +98,24 @@ if __name__ == '__main__':
|
|
|
|
|
parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
# init distributed
|
|
|
|
|
if args_opt.is_distributed:
|
|
|
|
|
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
|
|
|
|
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
|
|
|
|
rank = get_rank()
|
|
|
|
|
group_size = get_group_size()
|
|
|
|
|
parallel_mode = ParallelMode.DATA_PARALLEL
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
|
|
|
|
|
init()
|
|
|
|
|
else:
|
|
|
|
|
rank = 0
|
|
|
|
|
group_size = 1
|
|
|
|
|
context.set_context(device_id=0)
|
|
|
|
|
|
|
|
|
|
if args_opt.device_target == "Ascend":
|
|
|
|
|
#train on Ascend
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=False)
|
|
|
|
|
|
|
|
|
|
# init distributed
|
|
|
|
|
if args_opt.is_distributed:
|
|
|
|
|
if os.getenv('DEVICE_ID', "not_set").isdigit():
|
|
|
|
|
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
|
|
|
|
|
init()
|
|
|
|
|
rank = get_rank()
|
|
|
|
|
group_size = get_group_size()
|
|
|
|
|
parallel_mode = ParallelMode.DATA_PARALLEL
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=group_size, gradients_mean=True)
|
|
|
|
|
else:
|
|
|
|
|
rank = 0
|
|
|
|
|
group_size = 1
|
|
|
|
|
context.set_context(device_id=0)
|
|
|
|
|
|
|
|
|
|
# define network
|
|
|
|
|
net = xception(class_num=config.class_num)
|
|
|
|
|
net.to_float(mstype.float16)
|
|
|
|
|