diff --git a/model_zoo/wide_and_deep/train.py b/model_zoo/wide_and_deep/train.py index b3996e01cb..ac9750c547 100644 --- a/model_zoo/wide_and_deep/train.py +++ b/model_zoo/wide_and_deep/train.py @@ -14,7 +14,7 @@ """ test_training """ import os from mindspore import Model, context -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel from src.callbacks import LossCallBack @@ -75,7 +75,7 @@ def test_train(configure): ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5) ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig) - model.train(epochs, ds_train, callbacks=[callback, ckpoint_cb]) + model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb]) if __name__ == "__main__":