diff --git a/model_zoo/official/recommend/wide_and_deep/requirements.txt b/model_zoo/official/recommend/wide_and_deep/requirements.txt new file mode 100644 index 0000000000..5ab9999be1 --- /dev/null +++ b/model_zoo/official/recommend/wide_and_deep/requirements.txt @@ -0,0 +1,3 @@ +numpy +pandas +sklearn diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py index 5a7cf8c718..9e70cd1d68 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py @@ -109,8 +109,11 @@ def train_and_eval(config): directory=config.ckpt_path, config=ckptconfig) out = model.eval(ds_eval) print("=====" * 5 + "model.eval() initialized: {}".format(out)) + callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback] + if get_rank() == 0: + callback_list.append(ckpoint_cb) model.train(epochs, ds_train, - callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb], + callbacks=callback_list, sink_size=ds_train.get_dataset_size()) diff --git a/model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt b/model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt index 065ca418fd..9060256bb8 100644 --- a/model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt +++ b/model_zoo/official/recommend/wide_and_deep_multitable/requirements.txt @@ -1,3 +1,4 @@ numpy pandas pickle +sklearn