noah recommender system model AutoDis

pull/10285/head
chenbo 5 years ago
parent 06f80f7043
commit ffd8da9911

@ -73,6 +73,8 @@ In order to facilitate developers to enjoy the benefits of MindSpore framework,
- [CenterNet](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/centernet/README.md)
- [Natural Language Processing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp)
- [DS-CNN](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/nlp/dscnn/README.md)
- [Recommender Systems](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend)
- [AutoDis](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis/README.md)
- [Audio](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio)
- [FCN-4](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/audio/fcn-4/README.md)
- [High Performance Computing](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/hpc)

@ -0,0 +1,230 @@
# Contents
- [AutoDis Description](#AutoDis-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Inference Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [AutoDis Description](#contents)
The common methods for numerical feature embedding are Normalization and Discretization. The former shares a single embedding for intra-field features and the latter transforms the features into categorical form through various discretization approaches. However, the first approach surfers from low capacity and the second one limits performance as well because the discretization rule cannot be optimized with the ultimate goal of CTR model.
To fill the gap of representing numerical features, in this paper, we propose AutoDis, a framework that discretizes features in numerical fields automatically and is optimized with CTR models in an end-to-end manner. Specifically, we introduce a set of meta-embeddings for each numerical field to model the relationship among the intra-field features and propose an automatic differentiable discretization and aggregation approach to capture the correlations between the numerical features and meta-embeddings. AutoDis is a valid framework to work with various popular deep CTR models and is able to improve the recommendation performance significantly.
[Paper](https://arxiv.org/abs/2012.08986): Huifeng Guo*, Bo Chen*, Ruiming Tang, Zhenguo Li, Xiuqiang He. AutoDis: Automatic Discretization for Embedding Numerical Features in CTR Prediction
# [Model Architecture](#contents)
AutoDis leverages a set of meta-embeddings for each numerical field, which are shared among all the intra-field feature values. Meta-embeddings learn the relationship across different feature values in this field with a manageable number of embedding parameters. Utilizing meta-embedding is able to avoid explosive embedding parameters introduced by assigning each numerical feature with an independent embedding simply. Besides, the embedding of a numerical feature is designed as a differentiable aggregation over the shared meta-embeddings, so that the discretization of numerical features can be optimized with the ultimate goal of deep CTR models in an end-to-end manner.
# [Dataset](#contents)
- [1] A dataset [Criteo](https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz) used in Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction[J]. 2017.
# [Environment Requirements](#contents)
- HardwareAscend/GPU
- Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
# [Quick Start](#contents)
After installing MindSpore via the official website, you can start training and evaluation as follows:
- runing on Ascend
```python
# run training example
python train.py \
--dataset_path='dataset/train' \
--ckpt_path='./checkpoint' \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target='Ascend' \
--do_eval=True > ms_log/output.log 2>&1 &
# run evaluation example
python eval.py \
--dataset_path='dataset/test' \
--checkpoint_path='./checkpoint/autodis.ckpt' \
--device_target='Ascend' > ms_log/eval_output.log 2>&1 &
OR
sh scripts/run_eval.sh 0 Ascend /dataset_path /checkpoint_path/autodis.ckpt
```
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
Please follow the instructions in the link below:
<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
# [Script Description](#contents)
## [Script and Sample Code](#contents)
```bash
.
└─autodis
├─README.md
├─mindspore_hub_conf.md # config for mindspore hub
├─scripts
├─run_standalone_train.sh # launch standalone training(1p) in Ascend or GPU
└─run_eval.sh # launch evaluating in Ascend or GPU
├─src
├─__init__.py # python init file
├─config.py # parameter configuration
├─callback.py # define callback function
├─autodis.py # AutoDis network
├─dataset.py # create dataset for AutoDis
├─eval.py # eval net
└─train.py # train net
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in config.py
- train parameters
```python
optional arguments:
-h, --help show this help message and exit
--dataset_path DATASET_PATH
Dataset path
--ckpt_path CKPT_PATH
Checkpoint path
--eval_file_name EVAL_FILE_NAME
Auc log file path. Default: "./auc.log"
--loss_file_name LOSS_FILE_NAME
Loss log file path. Default: "./loss.log"
--do_eval DO_EVAL Do evaluation or not. Default: True
--device_target DEVICE_TARGET
Ascend or GPU. Default: Ascend
```
- eval parameters
```bash
optional arguments:
-h, --help show this help message and exit
--checkpoint_path CHECKPOINT_PATH
Checkpoint file path
--dataset_path DATASET_PATH
Dataset path
--device_target DEVICE_TARGET
Ascend or GPU. Default: Ascend
```
## [Training Process](#contents)
### Training
- running on Ascend
```python
python train.py \
--dataset_path='dataset/train' \
--ckpt_path='./checkpoint' \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target='Ascend' \
--do_eval=True > ms_log/output.log 2>&1 &
```
The python command above will run in the background, you can view the results through the file `ms_log/output.log`.
After training, you'll get some checkpoint files under `./checkpoint` folder by default. The loss value are saved in loss.log file.
```txt
2020-12-10 14:58:04 epoch: 1 step: 41257, loss is 0.44559600949287415
2020-12-10 15:06:59 epoch: 2 step: 41257, loss is 0.4370603561401367
...
```
The model checkpoint will be saved in the current directory.
## [Evaluation Process](#contents)
### Evaluation
- evaluation on dataset when running on Ascend
Before running the command below, please check the checkpoint path used for evaluation.
```python
python eval.py \
--dataset_path='dataset/test' \
--checkpoint_path='./checkpoint/autodis.ckpt' \
--device_target='Ascend' > ms_log/eval_output.log 2>&1 &
OR
sh scripts/run_eval.sh 0 Ascend /dataset_path /checkpoint_path/autodis.ckpt
```
The above python command will run in the background. You can view the results through the file "eval_output.log". The accuracy is saved in auc.log file.
```txt
{'result': {'AUC': 0.8109881454077731, 'eval_time': 27.72783327102661s}}
```
# [Model Description](#contents)
## [Performance](#contents)
### Evaluation Performance
| Parameters | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | AutoDis |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G |
| uploaded Date | 12/12/2020 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | [1] |
| Training Parameters | epoch=15, batch_size=1000, lr=1e-5 |
| Optimizer | Adam |
| Loss Function | Sigmoid Cross Entropy With Logits |
| outputs | Accuracy |
| Loss | 0.42 |
| Speed | 1pc: 8.16 ms/step; |
| Total time | 1pc: 90 mins; |
| Parameters (M) | 16.5 |
| Checkpoint for Fine tuning | 191M (.ckpt file) |
| Scripts | [AutoDis script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/recommend/autodis) |
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | AutoDis |
| Resource | Ascend 910 |
| Uploaded Date | 12/12/2020 (month/day/year) |
| MindSpore Version | 0.3.0-alpha |
| Dataset | [1] |
| batch_size | 1000 |
| outputs | accuracy |
| AUC | 1pc: 0.8112; |
| Model for inference | 191M (.ckpt file) |
# [Description of Random Situation](#contents)
We set the random seed before training in train.py.
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,68 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval_criteo."""
import os
import sys
import time
import argparse
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.autodis import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset, DataType
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend"],
help='Default: Ascend')
args_opt, _ = parser.parse_known_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
def add_write(file_path, print_str):
with open(file_path, 'a+', encoding='utf-8') as file_out:
file_out.write(print_str + '\n')
if __name__ == '__main__':
data_config = DataConfig()
model_config = ModelConfig()
train_config = TrainConfig()
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=1, batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
model_builder = ModelBuilder(ModelConfig, TrainConfig)
train_net, eval_net = model_builder.get_train_eval_net()
train_net.set_train()
eval_net.set_train(False)
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(eval_net, param_dict)
start = time.time()
res = model.eval(ds_eval)
eval_time = time.time() - start
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
out_str = f'{time_str} AUC: {list(res.values())[0]}, eval time: {eval_time}s.'
print(out_str)
add_write('./auc.log', str(out_str))

@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""hub config."""
from src.autodis import ModelBuilder
from src.config import ModelConfig, TrainConfig
def create_network(name, *args, **kwargs):
if name == 'autodis':
model_config = ModelConfig()
train_config = TrainConfig()
model_builder = ModelBuilder(model_config, train_config)
_, autodis_eval_net = model_builder.get_train_eval_net()
return autodis_eval_net
raise NotImplementedError(f"{name} is not implemented in the repo")

@ -0,0 +1,34 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_eval.sh DEVICE_ID DEVICE_TARGET DATASET_PATH CHECKPOINT_PATH"
echo "for example: sh scripts/run_eval.sh 0 GPU /dataset_path /checkpoint_path"
echo "After running the script, the network runs in the background, The log will be generated in ms_log/eval_output.log"
export DEVICE_ID=$1
DEVICE_TARGET=$2
DATA_URL=$3
CHECKPOINT_PATH=$4
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python -u eval.py \
--dataset_path=$DATA_URL \
--checkpoint_path=$CHECKPOINT_PATH \
--device_target=$DEVICE_TARGET > ms_log/eval_output.log 2>&1 &

@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_standalone_train.sh DEVICE_ID/CUDA_VISIBLE_DEVICES DEVICE_TARGET DATASET_PATH"
echo "for example: sh scripts/run_standalone_train.sh 0 GPU /dataset_path"
echo "After running the script, the network runs in the background, The log will be generated in ms_log/output.log"
DEVICE_TARGET=$2
if [ "$DEVICE_TARGET" = "GPU" ]
then
export CUDA_VISIBLE_DEVICES=$1
fi
if [ "$DEVICE_TARGET" = "Ascend" ]
then
export DEVICE_ID=$1
fi
DATA_URL=$3
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python -u train.py \
--dataset_path=$DATA_URL \
--ckpt_path="checkpoint" \
--eval_file_name='auc.log' \
--loss_file_name='loss.log' \
--device_target=$DEVICE_TARGET \
--do_eval=True > ms_log/output.log 2>&1 &

File diff suppressed because it is too large Load Diff

@ -0,0 +1,108 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the License);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# httpwww.apache.orglicensesLICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Defined callback for DeepFM.
"""
import time
from mindspore.train.callback import Callback
def add_write(file_path, out_str):
with open(file_path, 'a+', encoding='utf-8') as file_out:
file_out.write(out_str + '\n')
class EvalCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note
If per_print_times is 0 do not print loss.
"""
def __init__(self, model, eval_dataset, auc_metric, eval_file_path):
super(EvalCallBack, self).__init__()
self.model = model
self.eval_dataset = eval_dataset
self.aucMetric = auc_metric
self.aucMetric.clear()
self.eval_file_path = eval_file_path
def epoch_end(self, run_context):
start_time = time.time()
out = self.model.eval(self.eval_dataset)
eval_time = int(time.time() - start_time)
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
out_str = "{} EvalCallBack metric{}; eval_time{}s".format(
time_str, out.values(), eval_time)
print(out_str)
add_write(self.eval_file_path, out_str)
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note
If per_print_times is 0 do not print loss.
Args
loss_file_path (str) The file absolute path, to save as loss_file;
per_print_times (int) Print loss every times. Default 1.
"""
def __init__(self, loss_file_path, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self.loss_file_path = loss_file_path
self._per_print_times = per_print_times
def step_end(self, run_context):
"""Monitor the loss in training."""
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
with open(self.loss_file_path, "a+") as loss_file:
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
loss_file.write("{} epoch: {} step: {}, loss is {}\n".format(
time_str, cb_params.cur_epoch_num, cur_step_in_epoch, loss))
print("epoch: {} step: {}, loss is {}\n".format(
cb_params.cur_epoch_num, cur_step_in_epoch, loss))
class TimeMonitor(Callback):
"""
Time monitor for calculating cost of each epoch.
Args
data_size (int) step size of an epoch.
"""
def __init__(self, data_size):
super(TimeMonitor, self).__init__()
self.data_size = data_size
def epoch_begin(self, run_context):
self.epoch_time = time.time()
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / self.data_size
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
step_mseconds = (time.time() - self.step_time) * 1000
print(f"step time {step_mseconds}", flush=True)

@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py and eval.py
"""
class DataConfig:
"""
Define parameters of dataset.
"""
data_vocab_size = 184965
train_num_of_parts = 21
test_num_of_parts = 3
batch_size = 1000
data_field_size = 39
# dataset format, 1: mindrecord, 2: tfrecord, 3: h5
data_format = 2
class ModelConfig:
"""
Define parameters of model.
"""
batch_size = DataConfig.batch_size
data_field_size = DataConfig.data_field_size
data_vocab_size = DataConfig.data_vocab_size
data_emb_dim = 80
deep_layer_args = [[400, 400, 512], "relu"]
init_args = [-0.01, 0.01]
weight_bias_init = ['normal', 'normal']
keep_prob = 0.9
split_index = 13
hash_size = 20
temperature = 1e-5
class TrainConfig:
"""
Define parameters of training.
"""
batch_size = DataConfig.batch_size
l2_coef = 1e-6
learning_rate = 1e-5
epsilon = 1e-8
loss_scale = 1024.0
train_epochs = 15
save_checkpoint = True
ckpt_file_name_prefix = "autodis"
save_checkpoint_steps = 1
keep_checkpoint_max = 15
eval_callback = True
loss_callback = True

File diff suppressed because it is too large Load Diff

@ -0,0 +1,119 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_criteo."""
import os
import sys
import argparse
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
from mindspore.train.model import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.common import set_seed
#from mindspore.profiler import Profiler
from src.autodis import ModelBuilder, AUCMetric
from src.config import DataConfig, ModelConfig, TrainConfig
from src.dataset import create_dataset, DataType
from src.callback import EvalCallBack, LossCallBack
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
parser = argparse.ArgumentParser(description='CTR Prediction')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
parser.add_argument('--ckpt_path', type=str, default=None, help='Checkpoint path')
parser.add_argument('--eval_file_name', type=str, default="./auc.log",
help='Auc log file path. Default: "./auc.log"')
parser.add_argument('--loss_file_name', type=str, default="./loss.log",
help='Loss log file path. Default: "./loss.log"')
parser.add_argument('--do_eval', type=str, default='True', choices=["True", "False"],
help='Do evaluation or not, only support "True" or "False". Default: "True"')
parser.add_argument('--device_target', type=str, default="Ascend", choices=["Ascend"],
help='Default: Ascend')
args_opt, _ = parser.parse_known_args()
args_opt.do_eval = args_opt.do_eval == 'True'
rank_size = int(os.environ.get("RANK_SIZE", 1))
set_seed(1)
if __name__ == '__main__':
data_config = DataConfig()
model_config = ModelConfig()
train_config = TrainConfig()
if rank_size > 1:
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
init()
rank_id = int(os.environ.get('RANK_ID'))
else:
print("Unsupported device_target ", args_opt.device_target)
exit()
else:
if args_opt.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=device_id)
else:
print("Unsupported device_target ", args_opt.device_target)
exit()
rank_size = None
rank_id = None
# Init Profiler
#profiler = Profiler(output_path='./data', is_detail=True, is_show_op_path=False, subgraph='all')
ds_train = create_dataset(args_opt.dataset_path,
train_mode=True,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format),
rank_size=rank_size,
rank_id=rank_id)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
steps_size = ds_train.get_dataset_size()
model_builder = ModelBuilder(ModelConfig, TrainConfig)
train_net, eval_net = model_builder.get_train_eval_net()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
loss_callback = LossCallBack(loss_file_path=args_opt.loss_file_name)
callback_list = [time_callback, loss_callback]
if train_config.save_checkpoint:
if rank_size:
train_config.ckpt_file_name_prefix = train_config.ckpt_file_name_prefix + str(get_rank())
args_opt.ckpt_path = os.path.join(args_opt.ckpt_path, 'ckpt_' + str(get_rank()) + '/')
config_ck = CheckpointConfig(save_checkpoint_steps=train_config.save_checkpoint_steps,
keep_checkpoint_max=train_config.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix=train_config.ckpt_file_name_prefix,
directory=args_opt.ckpt_path,
config=config_ck)
callback_list.append(ckpt_cb)
if args_opt.do_eval:
ds_eval = create_dataset(args_opt.dataset_path, train_mode=False,
epochs=1,
batch_size=train_config.batch_size,
data_type=DataType(data_config.data_format))
eval_callback = EvalCallBack(model, ds_eval, auc_metric,
eval_file_path=args_opt.eval_file_name)
callback_list.append(eval_callback)
model.train(train_config.train_epochs, ds_train, callbacks=callback_list)
Loading…
Cancel
Save