wide&deep only save 0ckpt in data parallel

pull/4682/head
yao_yf 5 years ago
parent 245415f5bd
commit 8a0c47d367

@ -109,8 +109,11 @@ def train_and_eval(config):
directory=config.ckpt_path, config=ckptconfig)
out = model.eval(ds_eval)
print("=====" * 5 + "model.eval() initialized: {}".format(out))
callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
if get_rank() == 0:
callback_list.append(ckpoint_cb)
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],
callbacks=callback_list,
sink_size=ds_train.get_dataset_size())

Loading…
Cancel
Save