@@ -51,7 +51,7 @@ def run_pretrain():
     parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
     parser.add_argument("--enable_data_sink", type=str, default="true", help="Enable data sink, default is true.")
     parser.add_argument("--data_sink_steps", type=int, default="1", help="Sink steps for each epoch, default is 1.")
-    parser.add_argument("--save_checkpoint_path", type=str, default="", help="Save checkpoint path")
+    parser.add_argument("--save_checkpoint_path", type=str, default=None, help="Save checkpoint path")
     parser.add_argument("--load_checkpoint_path", type=str, default="", help="Load checkpoint file path")
     parser.add_argument("--save_checkpoint_steps", type=int, default=1000, help="Save checkpoint steps, "
                                                                                 "default is 1000.")

@@ -142,7 +142,7 @@ def run_pretrain():
         raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay]".
                          format(cfg.optimizer))
     callback = [TimeMonitor(args_opt.data_sink_steps), LossCallBack()]
-    if args_opt.enable_save_ckpt == "true":
+    if args_opt.enable_save_ckpt == "true" and args_opt.device_id % min(8, device_num) == 0:
         config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
                                      keep_checkpoint_max=args_opt.save_checkpoint_num)
         ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
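
The added guard registers the checkpoint callback on only one device per 8-card group (on Ascend, `device_id` runs 0..7 within a server), so co-located training processes no longer race to write identical `checkpoint_bert` files into the same directory. A small stand-alone sketch of the gating arithmetic (the helper name is ours, not from the script):

```python
# Hypothetical helper reproducing the new gate, handy for
# sanity-checking which device ids end up saving checkpoints.
def saves_checkpoint(device_id: int, device_num: int) -> bool:
    """True if this device would register the ModelCheckpoint callback."""
    return device_id % min(8, device_num) == 0

# Single 8-card node: only device 0 saves.
assert [d for d in range(8) if saves_checkpoint(d, 8)] == [0]

# Fewer than 8 devices: min(8, device_num) keeps the modulus sane,
# so a 2-card run still saves on device 0 only.
assert [d for d in range(2) if saves_checkpoint(d, 2)] == [0]
```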