|
|
|
@ -127,7 +127,7 @@ def parse_args(cloud_args=None):
|
|
|
|
|
# logging and checkpoint related
|
|
|
|
|
parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
|
|
|
|
|
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
|
|
|
|
|
parser.add_argument('--ckpt_interval', type=int, default=2, help='ckpt_interval')
|
|
|
|
|
parser.add_argument('--ckpt_interval', type=int, default=5, help='ckpt_interval')
|
|
|
|
|
parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
|
|
|
|
|
|
|
|
|
|
# distributed related
|
|
|
|
@ -200,12 +200,12 @@ if __name__ == '__main__':
|
|
|
|
|
device_num = args.group_size
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
|
|
|
|
mirror_mean=True)
|
|
|
|
|
parameter_broadcast=True, mirror_mean=True)
|
|
|
|
|
else:
|
|
|
|
|
context.set_context(device_id=args.device_id)
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
|
|
|
|
|
|
|
|
|
# select for master rank save ckpt or all rank save, compatiable for model parallel
|
|
|
|
|
# select for master rank save ckpt or all rank save, compatible for model parallel
|
|
|
|
|
args.rank_save_ckpt_flag = 0
|
|
|
|
|
if args.is_save_on_master:
|
|
|
|
|
if args.rank == 0:
|
|
|
|
|