|
|
|
@ -124,11 +124,15 @@ def train_and_eval(config):
|
|
|
|
|
eval_callback = EvalCallBack(
|
|
|
|
|
model, ds_eval, auc_metric, config)
|
|
|
|
|
|
|
|
|
|
# Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
|
|
|
|
|
config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
|
|
|
|
|
|
|
|
|
|
callback = LossCallBack(config=config, per_print_times=20)
|
|
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
|
|
|
|
|
keep_checkpoint_max=5, integrated_save=False)
|
|
|
|
|
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
|
|
|
|
|
directory=config.ckpt_path + '/ckpt_' + str(get_rank()) + '/', config=ckptconfig)
|
|
|
|
|
directory=os.path.join(config.ckpt_path, 'ckpt_' + str(get_rank())), config=ckptconfig)
|
|
|
|
|
|
|
|
|
|
context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
|
|
|
|
|
callback_list = [TimeMonitor(
|
|
|
|
|
ds_train.get_dataset_size()), eval_callback, callback]
|
|
|
|
|