|
|
@ -121,12 +121,12 @@ def train_and_eval(config):
|
|
|
|
model = Model(train_net, eval_network=eval_net,
|
|
|
|
model = Model(train_net, eval_network=eval_net,
|
|
|
|
metrics={"auc": auc_metric})
|
|
|
|
metrics={"auc": auc_metric})
|
|
|
|
|
|
|
|
|
|
|
|
eval_callback = EvalCallBack(
|
|
|
|
|
|
|
|
model, ds_eval, auc_metric, config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
|
|
|
|
# Save strategy ckpts according to the rank id, this must be done before initializing the callbacks.
|
|
|
|
config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
|
|
|
|
config.stra_ckpt = os.path.join(config.stra_ckpt + "-{}".format(get_rank()), "strategy.ckpt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval_callback = EvalCallBack(
|
|
|
|
|
|
|
|
model, ds_eval, auc_metric, config)
|
|
|
|
|
|
|
|
|
|
|
|
callback = LossCallBack(config=config, per_print_times=20)
|
|
|
|
callback = LossCallBack(config=config, per_print_times=20)
|
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
|
|
|
|
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
|
|
|
|
keep_checkpoint_max=5, integrated_save=False)
|
|
|
|
keep_checkpoint_max=5, integrated_save=False)
|
|
|
@ -146,10 +146,11 @@ if __name__ == "__main__":
|
|
|
|
wide_deep_config = WideDeepConfig()
|
|
|
|
wide_deep_config = WideDeepConfig()
|
|
|
|
wide_deep_config.argparse_init()
|
|
|
|
wide_deep_config.argparse_init()
|
|
|
|
context.set_context(mode=context.GRAPH_MODE,
|
|
|
|
context.set_context(mode=context.GRAPH_MODE,
|
|
|
|
device_target=wide_deep_config.device_target, save_graphs=True)
|
|
|
|
device_target=wide_deep_config.device_target)
|
|
|
|
context.set_context(variable_memory_max_size="24GB")
|
|
|
|
context.set_context(variable_memory_max_size="24GB")
|
|
|
|
context.set_context(enable_sparse=True)
|
|
|
|
context.set_context(enable_sparse=True)
|
|
|
|
init()
|
|
|
|
init()
|
|
|
|
|
|
|
|
context.set_context(save_graphs_path='./graphs_of_device_id_' + str(get_rank()), save_graphs=True)
|
|
|
|
if wide_deep_config.sparse:
|
|
|
|
if wide_deep_config.sparse:
|
|
|
|
context.set_auto_parallel_context(
|
|
|
|
context.set_auto_parallel_context(
|
|
|
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
|
|
|
|
parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True)
|
|
|
|