|
|
|
@ -14,7 +14,7 @@
|
|
|
|
|
""" test_training """
|
|
|
|
|
import os
|
|
|
|
|
from mindspore import Model, context
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
|
|
|
|
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
|
|
|
|
|
|
|
|
|
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
|
|
|
|
|
from src.callbacks import LossCallBack
|
|
|
|
@ -75,7 +75,7 @@ def test_train(configure):
|
|
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=1,
|
|
|
|
|
keep_checkpoint_max=5)
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=configure.ckpt_path, config=ckptconfig)
|
|
|
|
|
model.train(epochs, ds_train, callbacks=[callback, ckpoint_cb])
|
|
|
|
|
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|