From 8a0c47d3671c9f86ce2a982ebea1f50469408aa3 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Tue, 18 Aug 2020 19:44:23 +0800 Subject: [PATCH] wide&deep only save 0ckpt in data parallel --- model_zoo/official/recommend/wide_and_deep/requirements.txt | 3 +++ .../recommend/wide_and_deep/train_and_eval_distribute.py | 5 ++++- .../recommend/wide_and_deep_multitable/requirements.txt | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 model_zoo/official/recommend/wide_and_deep/requirements.txt diff --git a/model_zoo/official/recommend/wide_and_deep/requirements.txt b/model_zoo/official/recommend/wide_and_deep/requirements.txt new file mode 100644 index 0000000000..5ab9999be1 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/requirements.txt @@ -0,0 +1,3 @@ +numpy +pandas +sklearn diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py index 5a7cf8c718..9e70cd1d68 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py @@ -109,8 +109,11 @@ def train_and_eval(config): directory=config.ckpt_path, config=ckptconfig) out = model.eval(ds_eval) print("=====" * 5 + "model.eval() initialized: {}".format(out)) + callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] + if get_rank() == 0: + callback_list.append(ckpoint_cb) model.train(epochs, ds_train, - callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb], + callbacks=callback_list, sink_size=ds_train.get_dataset_size()) diff --git a/model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt b/model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt index 065ca418fd..9060256bb8 100644 --- a/model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt +++ b/model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt @@ -1,3 +1,4 @@ numpy pandas pickle +sklearn