parent ded9608f6d
commit 92c1b2bd31
@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
@@ -0,0 +1,66 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_criteo."""
import os
import sys
import time
import argparse

from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')

args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)


def add_write(file_path, print_str):
    with open(file_path, 'a+', encoding='utf-8') as file_out:
        file_out.write(print_str + '\n')


if __name__ == '__main__':
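    # Evaluation flow: build the eval dataset, construct the train/eval networks,
    # load the trained checkpoint into the eval network, run model.eval to compute
    # the AUC metric, and append the timed result to ./auc.log.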
    data_config = DataConfig()
    model_config = ModelConfig()
    train_config = TrainConfig()

    ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
                             epochs=1, batch_size=train_config.batch_size)
    model_builder = ModelBuilder(ModelConfig, TrainConfig)
    train_net, eval_net = model_builder.get_train_eval_net()
    train_net.set_train()
    eval_net.set_train(False)
    auc_metric = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(eval_net, param_dict)

    start = time.time()
    res = model.eval(ds_eval)
    eval_time = time.time() - start
    time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
    print(out_str)
    add_write('./auc.log', str(out_str))
@@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PATH"
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
echo "After running the script, the network runs in the background. The log will be generated in log<rank_id>/output.log"


export RANK_SIZE=$1
DATA_URL=$2
export MINDSPORE_HCCL_CONFIG_PATH=$3

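# Launch one training process per device: each rank gets its own working
# directory (log$i) with a copy of the code, its own DEVICE_ID/RANK_ID, and
# its output redirected to log$i/output.log.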
for ((i = 0; i < RANK_SIZE; i++))
do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf log$i
    mkdir ./log$i
    cp *.py ./log$i
    cp -r src ./log$i
    cd ./log$i || exit
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    python -u train.py \
        --dataset_path=$DATA_URL \
        --ckpt_path="checkpoint" \
        --eval_file_name='auc.log' \
        --loss_file_name='loss.log' \
        --do_eval=True > output.log 2>&1 &
    cd ../
done
@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH"
echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path"
echo "After running the script, the network runs in the background. The log will be generated in ms_log/eval_output.log"

export DEVICE_ID=$1
DATA_URL=$2
CHECKPOINT_PATH=$3

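# Route MindSpore framework (glog) output into files under ./ms_log instead of stderr.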
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0

python -u eval.py \
    --dataset_path=$DATA_URL \
    --checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 &
@@ -0,0 +1,34 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH"
echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path"
echo "After running the script, the network runs in the background. The log will be generated in ms_log/output.log"

export DEVICE_ID=$1
DATA_URL=$2

mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0

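# Start single-device training in the background: training loss is written to
# loss.log, per-epoch AUC (with --do_eval=True) to auc.log, stdout to ms_log/output.log.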
python -u train.py \
    --dataset_path=$DATA_URL \
    --ckpt_path="checkpoint" \
    --eval_file_name='auc.log' \
    --loss_file_name='loss.log' \
    --do_eval=True > ms_log/output.log 2>&1 &
@@ -0,0 +1,14 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
@@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Defined callback for DeepFM.
"""
import time
from mindspore.train.callback import Callback


def add_write(file_path, out_str):
    with open(file_path, 'a+', encoding='utf-8') as file_out:
        file_out.write(out_str + '\n')


class EvalCallBack(Callback):
    """
    Run evaluation on eval_dataset at the end of each epoch and append the
    result to eval_file_path.
    """
    def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
        super(EvalCallBack, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.aucMetric = auc_metric
        self.aucMetric.clear()
        self.eval_file_path = eval_file_path

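    # epoch_end: run a full pass over eval_dataset, time it, and append the
    # metric value together with a timestamp to the eval log file.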
    def epoch_end(self, run_context):
        start_time = time.time()
        out = self.model.eval(self.eval_dataset)
        eval_time = int(time.time() - start_time)
        time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        out_str = "{} EvalCallBack metric: {}; eval_time: {}s".format(
            time_str, out.values(), eval_time)
        print(out_str)
        add_write(self.eval_file_path, out_str)


class LossCallBack(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, terminate training.

    Note:
        If per_print_times is 0, do not print loss.

    Args:
        loss_file_path (str): The absolute path of the file the loss is saved to.
        per_print_times (int): Print loss every per_print_times steps. Default: 1.
    """
    def __init__(self, loss_file_path, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self.loss_file_path = loss_file_path
        self._per_print_times = per_print_times

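    # step_end: every per_print_times steps, append the current epoch, step and
    # loss value (with a timestamp) to the loss file and echo it to stdout.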
    def step_end(self, run_context):
        cb_params = run_context.original_args()
        loss = cb_params.net_outputs.asnumpy()
        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        cur_num = cb_params.cur_step_num
        if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
            with open(self.loss_file_path, "a+") as loss_file:
                time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
                loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
                    time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
            print("epoch: {} step: {}, loss is {}\n".format(
                cb_params.cur_epoch_num, cur_step_in_epoch, loss))


class TimeMonitor(Callback):
    """
    Time monitor for calculating cost of each epoch.

    Args:
        data_size (int): step size of an epoch.
    """
    def __init__(self, data_size):
        super(TimeMonitor, self).__init__()
        self.data_size = data_size

    def epoch_begin(self, run_context):
        self.epoch_time = time.time()

    def epoch_end(self, run_context):
        epoch_mseconds = (time.time() - self.epoch_time) * 1000
        per_step_mseconds = epoch_mseconds / self.data_size
        print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)

    def step_begin(self, run_context):
        self.step_time = time.time()

    def step_end(self, run_context):
        step_mseconds = (time.time() - self.step_time) * 1000
        print(f"step time {step_mseconds}", flush=True)
@@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Network config settings, used in train.py and eval.py.
"""


class DataConfig:
    """
    Define parameters of dataset.
    """
    data_vocab_size = 184965
    train_num_of_parts = 21
    test_num_of_parts = 3
    batch_size = 1000
    data_field_size = 39
    # dataset format, 1: mindrecord, 2: tfrecord, 3: h5
    data_format = 2


class ModelConfig:
    """
    Define parameters of model.
    """
    batch_size = DataConfig.batch_size
    data_field_size = DataConfig.data_field_size
    data_vocab_size = DataConfig.data_vocab_size
    data_emb_dim = 80
    deep_layer_args = [[400, 400, 512], "relu"]
    init_args = [-0.01, 0.01]
    weight_bias_init = ['normal', 'normal']
    keep_prob = 0.9


class TrainConfig:
    """
    Define parameters of training.
    """
    batch_size = DataConfig.batch_size
    l2_coef = 1e-6
    learning_rate = 1e-5
    epsilon = 1e-8
    loss_scale = 1024.0
    train_epochs = 15
    save_checkpoint = True
    ckpt_file_name_prefix = "deepfm"
    save_checkpoint_steps = 1
    keep_checkpoint_max = 15
    eval_callback = True
    loss_callback = True
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,91 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_criteo."""
import os
import sys
import argparse

from mindspore import context, ParallelMode
from mindspore.communication.management import init
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor

from src.deepfm import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset, DataType
from src.callback import EvalCallBack, LossCallBack

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path')
parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path')
parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.')

args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)


if __name__ == '__main__':
    data_config = DataConfig()
    model_config = ModelConfig()
    train_config = TrainConfig()

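    # Distributed setup: when launched by run_distribute_train.sh, RANK_SIZE > 1,
    # so switch to data-parallel training and initialize the communication
    # backend with init(); otherwise fall back to single-device training.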
    rank_size = int(os.environ.get("RANK_SIZE", 1))
    if rank_size > 1:
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
        init()
        rank_id = int(os.environ.get('RANK_ID'))
    else:
        rank_size = None
        rank_id = None

    ds_train = create_dataset(args_opt.dataset_path,
                              train_mode=True,
                              epochs=train_config.train_epochs,
                              batch_size=train_config.batch_size,
                              data_type=DataType(data_config.data_format),
                              rank_size=rank_size,
                              rank_id=rank_id)

    model_builder = ModelBuilder(ModelConfig, TrainConfig)
    train_net, eval_net = model_builder.get_train_eval_net()
    auc_metric = AUCMetric()
    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})

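    # Callbacks: TimeMonitor reports epoch/step timing, LossCallBack writes the
    # training loss to loss_file_name; checkpointing and per-epoch evaluation are
    # appended below when enabled by the config and the --do_eval flag.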
    time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
    loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
    callback_list = [time_callback, loss_callback]

    if train_config.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
                                     keep_checkpoint_max=train_config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
                                  directory=args_opt.ckpt_path,
                                  config=config_ck)
        callback_list.append(ckpt_cb)

    if args_opt.do_eval:
        ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
                                 epochs=train_config.train_epochs,
                                 batch_size=train_config.batch_size,
                                 data_type=DataType(data_config.data_format))
        eval_callback = EvalCallBack(model, ds_eval, auc_metric,
                                     eval_file_path=args_opt.eval_file_name)
        callback_list.append(eval_callback)
    model.train(train_config.train_epochs, ds_train, callbacks=callback_list)