modify wide&deep

pull/1738/head
wukesong 5 years ago
parent c52576336a
commit 9a710b20d7

@ -0,0 +1,17 @@
#!/bin/bash
# Launch one background pytest run of train_and_test_multinpu.py per device id
# (0 .. RANK_SIZE-1), each inside its own freshly created device_<id>/ directory.
#
# Usage: bash run_multinpu_train.sh
#
# All paths are resolved against the current working directory, so run this
# from the directory that holds rank_table_8p.json and
# train_and_test_multinpu.py.
execute_path=$(pwd)
# Expansions are quoted throughout so a working directory containing spaces or
# glob characters cannot word-split -- most importantly in the `rm -rf` below.
export RANK_TABLE_FILE="${execute_path}/rank_table_8p.json"
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH="${execute_path}/rank_table_8p.json"
for ((i = 0; i < RANK_SIZE; i++));
do
    device_dir="${execute_path}/device_${i}"
    # Start from an empty per-device directory; results of a previous run are discarded.
    rm -rf "${device_dir}/"
    mkdir "${device_dir}/"
    # Abort the whole launch if the directory cannot be entered.
    cd "${device_dir}/" || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    # Runs in the background; stdout and stderr go to train_deep<id>.log in the device directory.
    pytest -s "${execute_path}/train_and_test_multinpu.py" >"train_deep${i}.log" 2>&1 &
done

@ -82,7 +82,7 @@ def test_train_eval(config):
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)
out = model.eval(ds_eval)

@ -30,7 +30,7 @@ from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
context.set_context(mode=GRAPH_MODE, device_target="Davinci", save_graph=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
init()
@ -71,8 +71,8 @@ def test_train_eval():
test_train_eval
"""
np.random.seed(1000)
config = WideDeepConfig
data_path = Config.data_path
config = WideDeepConfig()
data_path = config.data_path
batch_size = config.batch_size
epochs = config.epochs
print("epochs is {}".format(epochs))
@ -94,8 +94,14 @@ def test_train_eval():
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config=config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=1, keep_checkpoint_max=5)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
out = model.eval(ds_eval)
print("=====" * 5 + "model.eval() initialized: {}".format(out))
model.train(epochs, ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb])
if __name__ == "__main__":
test_train_eval()

Loading…
Cancel
Save