parent
ded9608f6d
commit
92c1b2bd31
@ -0,0 +1,14 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the License);
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# httpwww.apache.orglicensesLICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an AS IS BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
@ -0,0 +1,66 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_criteo."""
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
|
||||
from mindspore import context
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.deepfm import ModelBuilder, AUCMetric
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
from src.dataset import create_dataset
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parser = argparse.ArgumentParser(description='CTR Prediction')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
|
||||
args_opt, _ = parser.parse_known_args()
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
|
||||
|
||||
|
||||
def add_write(file_path, print_str):
|
||||
with open(file_path, 'a+', encoding='utf-8') as file_out:
|
||||
file_out.write(print_str + '\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_config = DataConfig()
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
|
||||
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
|
||||
epochs=1, batch_size=train_config.batch_size)
|
||||
model_builder = ModelBuilder(ModelConfig, TrainConfig)
|
||||
train_net, eval_net = model_builder.get_train_eval_net()
|
||||
train_net.set_train()
|
||||
eval_net.set_train(False)
|
||||
auc_metric = AUCMetric()
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(eval_net, param_dict)
|
||||
|
||||
start = time.time()
|
||||
res = model.eval(ds_eval)
|
||||
eval_time = time.time() - start
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
|
||||
print(out_str)
|
||||
add_write('./auc.log', str(out_str))
|
@ -0,0 +1,44 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "Please run the script as: "
|
||||
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH MINDSPORE_HCCL_CONFIG_PAHT"
|
||||
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in logx/output.log"
|
||||
|
||||
|
||||
export RANK_SIZE=$1
|
||||
DATA_URL=$2
|
||||
export MINDSPORE_HCCL_CONFIG_PAHT=$3
|
||||
|
||||
for ((i=0; i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf log$i
|
||||
mkdir ./log$i
|
||||
cp *.py ./log$i
|
||||
cp -r src ./log$i
|
||||
cd ./log$i || exit
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python -u train.py \
|
||||
--dataset_path=$DATA_URL \
|
||||
--ckpt_path="checkpoint" \
|
||||
--eval_file_name='auc.log' \
|
||||
--loss_file_name='loss.log' \
|
||||
--do_eval=True > output.log 2>&1 &
|
||||
cd ../
|
||||
done
|
@ -0,0 +1,32 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "Please run the script as: "
|
||||
echo "sh scripts/run_eval.sh DEVICE_ID DATASET_PATH CHECKPOINT_PATH"
|
||||
echo "for example: sh scripts/run_eval.sh 0 /dataset_path /checkpoint_path"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in ms_log/eval_output.log"
|
||||
|
||||
export DEVICE_ID=$1
|
||||
DATA_URL=$2
|
||||
CHECKPOINT_PATH=$3
|
||||
|
||||
mkdir -p ms_log
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
|
||||
python -u eval.py \
|
||||
--dataset_path=$DATA_URL \
|
||||
--checkpoint_path=$CHECKPOINT_PATH > ms_log/eval_output.log 2>&1 &
|
@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
echo "Please run the script as: "
|
||||
echo "sh scripts/run_standalone_train.sh DEVICE_ID DATASET_PATH"
|
||||
echo "for example: sh scripts/run_standalone_train.sh 0 /dataset_path"
|
||||
echo "After running the script, the network runs in the background, The log will be generated in ms_log/output.log"
|
||||
|
||||
export DEVICE_ID=$1
|
||||
DATA_URL=$2
|
||||
|
||||
mkdir -p ms_log
|
||||
CUR_DIR=`pwd`
|
||||
export GLOG_log_dir=${CUR_DIR}/ms_log
|
||||
export GLOG_logtostderr=0
|
||||
|
||||
python -u train.py \
|
||||
--dataset_path=$DATA_URL \
|
||||
--ckpt_path="checkpoint" \
|
||||
--eval_file_name='auc.log' \
|
||||
--loss_file_name='loss.log' \
|
||||
--do_eval=True > ms_log/output.log 2>&1 &
|
@ -0,0 +1,14 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the License);
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# httpwww.apache.orglicensesLICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an AS IS BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
@ -0,0 +1,107 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the License);
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# httpwww.apache.orglicensesLICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an AS IS BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
Defined callback for DeepFM.
|
||||
"""
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
|
||||
|
||||
def add_write(file_path, out_str):
|
||||
with open(file_path, 'a+', encoding='utf-8') as file_out:
|
||||
file_out.write(out_str + '\n')
|
||||
|
||||
|
||||
class EvalCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss is NAN or INF terminating training.
|
||||
Note
|
||||
If per_print_times is 0 do not print loss.
|
||||
"""
|
||||
def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
|
||||
super(EvalCallBack, self).__init__()
|
||||
self.model = model
|
||||
self.eval_dataset = eval_dataset
|
||||
self.aucMetric = auc_metric
|
||||
self.aucMetric.clear()
|
||||
self.eval_file_path = eval_file_path
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
start_time = time.time()
|
||||
out = self.model.eval(self.eval_dataset)
|
||||
eval_time = int(time.time() - start_time)
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
out_str = "{} EvalCallBack metric{}; eval_time{}s".format(
|
||||
time_str, out.values(), eval_time)
|
||||
print(out_str)
|
||||
add_write(self.eval_file_path, out_str)
|
||||
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
If the loss is NAN or INF terminating training.
|
||||
Note
|
||||
If per_print_times is 0 do not print loss.
|
||||
Args
|
||||
loss_file_path (str) The file absolute path, to save as loss_file;
|
||||
per_print_times (int) Print loss every times. Default 1.
|
||||
"""
|
||||
def __init__(self, loss_file_path, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self.loss_file_path = loss_file_path
|
||||
self._per_print_times = per_print_times
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
loss = cb_params.net_outputs.asnumpy()
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
cur_num = cb_params.cur_step_num
|
||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
||||
with open(self.loss_file_path, "a+") as loss_file:
|
||||
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
|
||||
time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
|
||||
print("epoch: {} step: {}, loss is {}\n".format(
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch, loss))
|
||||
|
||||
|
||||
class TimeMonitor(Callback):
|
||||
"""
|
||||
Time monitor for calculating cost of each epoch.
|
||||
Args
|
||||
data_size (int) step size of an epoch.
|
||||
"""
|
||||
def __init__(self, data_size):
|
||||
super(TimeMonitor, self).__init__()
|
||||
self.data_size = data_size
|
||||
|
||||
def epoch_begin(self, run_context):
|
||||
self.epoch_time = time.time()
|
||||
|
||||
def epoch_end(self, run_context):
|
||||
epoch_mseconds = (time.time() - self.epoch_time) * 1000
|
||||
per_step_mseconds = epoch_mseconds / self.data_size
|
||||
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
|
||||
|
||||
def step_begin(self, run_context):
|
||||
self.step_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
step_mseconds = (time.time() - self.step_time) * 1000
|
||||
print(f"step time {step_mseconds}", flush=True)
|
@ -0,0 +1,62 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
|
||||
|
||||
class DataConfig:
|
||||
"""
|
||||
Define parameters of dataset.
|
||||
"""
|
||||
data_vocab_size = 184965
|
||||
train_num_of_parts = 21
|
||||
test_num_of_parts = 3
|
||||
batch_size = 1000
|
||||
data_field_size = 39
|
||||
# dataset format, 1: mindrecord, 2: tfrecord, 3: h5
|
||||
data_format = 2
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
"""
|
||||
Define parameters of model.
|
||||
"""
|
||||
batch_size = DataConfig.batch_size
|
||||
data_field_size = DataConfig.data_field_size
|
||||
data_vocab_size = DataConfig.data_vocab_size
|
||||
data_emb_dim = 80
|
||||
deep_layer_args = [[400, 400, 512], "relu"]
|
||||
init_args = [-0.01, 0.01]
|
||||
weight_bias_init = ['normal', 'normal']
|
||||
keep_prob = 0.9
|
||||
|
||||
|
||||
class TrainConfig:
|
||||
"""
|
||||
Define parameters of training.
|
||||
"""
|
||||
batch_size = DataConfig.batch_size
|
||||
l2_coef = 1e-6
|
||||
learning_rate = 1e-5
|
||||
epsilon = 1e-8
|
||||
loss_scale = 1024.0
|
||||
train_epochs = 15
|
||||
save_checkpoint = True
|
||||
ckpt_file_name_prefix = "deepfm"
|
||||
save_checkpoint_steps = 1
|
||||
keep_checkpoint_max = 15
|
||||
eval_callback = True
|
||||
loss_callback = True
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,91 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""train_criteo."""
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
from mindspore import context, ParallelMode
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
|
||||
|
||||
from src.deepfm import ModelBuilder, AUCMetric
|
||||
from src.config import DataConfig, ModelConfig, TrainConfig
|
||||
from src.dataset import create_dataset, DataType
|
||||
from src.callback import EvalCallBack, LossCallBack
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
parser = argparse.ArgumentParser(description='CTR Prediction')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
|
||||
parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
|
||||
parser.add_argument('--eval_file_name', type=str, default="./auc.log", help='eval file path')
|
||||
parser.add_argument('--loss_file_name', type=str, default="./loss.log", help='loss file path')
|
||||
parser.add_argument('--do_eval', type=bool, default=True, help='Do evaluation or not.')
|
||||
|
||||
args_opt, _ = parser.parse_known_args()
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_config = DataConfig()
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
|
||||
rank_size = int(os.environ.get("RANK_SIZE", 1))
|
||||
if rank_size > 1:
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
|
||||
init()
|
||||
rank_id = int(os.environ.get('RANK_ID'))
|
||||
else:
|
||||
rank_size = None
|
||||
rank_id = None
|
||||
|
||||
ds_train = create_dataset(args_opt.dataset_path,
|
||||
train_mode=True,
|
||||
epochs=train_config.train_epochs,
|
||||
batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format),
|
||||
rank_size=rank_size,
|
||||
rank_id=rank_id)
|
||||
|
||||
model_builder = ModelBuilder(ModelConfig, TrainConfig)
|
||||
train_net, eval_net = model_builder.get_train_eval_net()
|
||||
auc_metric = AUCMetric()
|
||||
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
|
||||
|
||||
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
|
||||
loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
|
||||
callback_list = [time_callback, loss_callback]
|
||||
|
||||
if train_config.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
|
||||
keep_checkpoint_max=train_config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
|
||||
directory=args_opt.ckpt_path,
|
||||
config=config_ck)
|
||||
callback_list.append(ckpt_cb)
|
||||
|
||||
if args_opt.do_eval:
|
||||
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
|
||||
epochs=train_config.train_epochs,
|
||||
batch_size=train_config.batch_size,
|
||||
data_type=DataType(data_config.data_format))
|
||||
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
|
||||
eval_file_path=args_opt.eval_file_name)
|
||||
callback_list.append(eval_callback)
|
||||
model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
|
Loading…
Reference in new issue