Add rank id to the ckpt path

pull/10357/head
huangxinjing 4 years ago
parent 61ed05f133
commit c218b0314c

@ -40,7 +40,7 @@ def argparse_init():
parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout")
parser.add_argument("--output_path", type=str, default="./output/")
parser.add_argument("--ckpt_path", type=str, default="./", help="The location of the checkpoint file.")
parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt",
parser.add_argument("--stra_ckpt", type=str, default="./checkpoints",
help="The strategy checkpoint file.")
parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.")
parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.")

@ -124,11 +124,15 @@ def train_and_eval(config):
eval_callback = EvalCallBack(
model, ds_eval, auc_metric, config)
# Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
callback = LossCallBack(config=config, per_print_times=20)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
keep_checkpoint_max=5, integrated_save=False)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig)
directory=os.path.join(config.ckpt_path, 'ckpt_' + str(get_rank())), config=ckptconfig)
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
callback_list = [TimeMonitor(
ds_train.get_dataset_size()), eval_callback, callback]

@ -115,6 +115,10 @@ def train_and_eval(config):
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
if cache_enable:
config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
@ -129,9 +133,6 @@ def train_and_eval(config):
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/',
config=ckptconfig)
if cache_enable:
config.stra_ckpt = './stra_ckpt_' + str(get_rank()) + '/strategy.ckpt'
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if get_rank() == 0:
callback_list.append(ckpoint_cb)

Loading…
Cancel
Save