diff --git a/model_zoo/official/recommend/wide_and_deep/src/config.py b/model_zoo/official/recommend/wide_and_deep/src/config.py index 1c206854b6..decaee567f 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/config.py +++ b/model_zoo/official/recommend/wide_and_deep/src/config.py @@ -40,7 +40,7 @@ def argparse_init(): parser.add_argument("--dropout_flag", type=int, default=0, help="Enable dropout") parser.add_argument("--output_path", type=str, default="./output/") parser.add_argument("--ckpt_path", type=str, default="./", help="The location of the checkpoint file.") - parser.add_argument("--stra_ckpt", type=str, default="./checkpoints/strategy.ckpt", + parser.add_argument("--stra_ckpt", type=str, default="./checkpoints", help="The strategy checkpoint file.") parser.add_argument("--eval_file_name", type=str, default="eval.log", help="Eval output file.") parser.add_argument("--loss_file_name", type=str, default="loss.log", help="Loss output file.") diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py index 6aca38a273..dee2a95ff6 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py @@ -124,11 +124,15 @@ def train_and_eval(config): eval_callback = EvalCallBack( model, ds_eval, auc_metric, config) + # Save strategy ckpts according to the rank id, this must be done before initializing the callbacks. + config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt") + callback = LossCallBack(config=config, per_print_times=20) ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, keep_checkpoint_max=5, integrated_save=False) ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', - directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig) + directory=os.path.join(config.ckpt_path, 'ckpt_' + str(get_rank())), config=ckptconfig) + context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) callback_list = [TimeMonitor( ds_train.get_dataset_size()), eval_callback, callback] diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py index 0bad048a6e..19a93d2ee0 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server_distribute.py @@ -115,6 +115,10 @@ def train_and_eval(config): model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric}) + if cache_enable: + config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt") + context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) + eval_callback = EvalCallBack(model, ds_eval, auc_metric, config) callback = LossCallBack(config=config) @@ -129,9 +133,6 @@ def train_and_eval(config): ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig) - if cache_enable: - config.stra_ckpt = './stra_ckpt_' + str(get_rank()) + '/strategy.ckpt' - context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt) callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] if get_rank() == 0: callback_list.append(ckpoint_cb)