|
|
@ -30,7 +30,7 @@ from src.metrics import AUCMetric
|
|
|
|
from src.config import WideDeepConfig
|
|
|
|
from src.config import WideDeepConfig
|
|
|
|
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
context.set_context(mode=GRAPH_MODE, device_target="Davinci", save_graph=True)
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
|
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
|
|
|
|
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
|
|
|
|
init()
|
|
|
|
init()
|
|
|
|
|
|
|
|
|
|
|
@ -71,8 +71,8 @@ def test_train_eval():
|
|
|
|
test_train_eval
|
|
|
|
test_train_eval
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
np.random.seed(1000)
|
|
|
|
np.random.seed(1000)
|
|
|
|
config = WideDeepConfig
|
|
|
|
config = WideDeepConfig()
|
|
|
|
data_path = Config.data_path
|
|
|
|
data_path = config.data_path
|
|
|
|
batch_size = config.batch_size
|
|
|
|
batch_size = config.batch_size
|
|
|
|
epochs = config.epochs
|
|
|
|
epochs = config.epochs
|
|
|
|
print("epochs is {}".format(epochs))
|
|
|
|
print("epochs is {}".format(epochs))
|
|
|
@ -94,8 +94,14 @@ def test_train_eval():
|
|
|
|
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
|
|
|
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
|
|
|
|
|
|
|
|
|
|
|
|
callback = LossCallBack(config=config)
|
|
|
|
callback = LossCallBack(config=config)
|
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)
|
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
|
|
|
directory=config.ckpt_path, config=ckptconfig)
|
|
|
|
directory=config.ckpt_path, config=ckptconfig)
|
|
|
|
|
|
|
|
out = model.eval(ds_eval)
|
|
|
|
|
|
|
|
print("=====" * 5 + "model.eval() initialized: {}".format(out))
|
|
|
|
model.train(epochs, ds_train,
|
|
|
|
model.train(epochs, ds_train,
|
|
|
|
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
|
|
|
|
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|
|
|
test_train_eval()
|
|
|
|