|
|
|
@ -68,7 +68,8 @@ def run_pretrain():
|
|
|
|
|
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
|
|
|
|
|
parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
|
|
|
|
|
parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
|
|
|
|
|
parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint file path")
|
|
|
|
|
parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
|
|
|
|
|
parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
|
|
|
|
|
parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
|
|
|
|
|
"default is 1000.")
|
|
|
|
|
parser.add_argument("--train_steps", type=int, default=-1, help="Training Steps, default is -1, "
|
|
|
|
@ -81,7 +82,7 @@ def run_pretrain():
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
|
|
|
|
|
context.set_context(reserve_class_name_in_scope=False)
|
|
|
|
|
|
|
|
|
|
ckpt_save_dir = args_opt.checkpoint_path
|
|
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path
|
|
|
|
|
if args_opt.distribute == "true":
|
|
|
|
|
if args_opt.device_target == 'Ascend':
|
|
|
|
|
D.init('hccl')
|
|
|
|
@ -91,7 +92,7 @@ def run_pretrain():
|
|
|
|
|
D.init('nccl')
|
|
|
|
|
device_num = D.get_group_size()
|
|
|
|
|
rank = D.get_rank()
|
|
|
|
|
ckpt_save_dir = args_opt.checkpoint_path + 'ckpt_' + str(rank) + '/'
|
|
|
|
|
ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
|
|
|
|
|
|
|
|
|
|
context.reset_auto_parallel_context()
|
|
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
|
|
|
|
@ -150,8 +151,8 @@ def run_pretrain():
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
|
|
|
|
|
callback.append(ckpoint_cb)
|
|
|
|
|
|
|
|
|
|
if args_opt.checkpoint_path:
|
|
|
|
|
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
|
|
|
|
if args_opt.load_checkpoint_path:
|
|
|
|
|
param_dict = load_checkpoint(args_opt.load_checkpoint_path)
|
|
|
|
|
load_param_into_net(netwithloss, param_dict)
|
|
|
|
|
|
|
|
|
|
if args_opt.enable_lossscale == "true":
|
|
|
|
|