add bgcf gpu

pull/11728/head
牛黄解毒片 4 years ago committed by panfengfeng
parent 54b8d53780
commit 1bff5f043d

File diff suppressed because it is too large Load Diff

@ -40,39 +40,46 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
所用数据集的统计信息摘要如下:
| | Amazon-Beauty |
| ------------------ | -----------------------:|
| 任务 | 推荐 |
| # 用户 | 7068 (1图) |
| # 物品 | 3570 |
| # 交互 | 79506 |
| # 训练数据 | 60818 |
| # 测试数据 | 18688 |
| # 密度 | 0.315% |
| | Amazon-Beauty |
| ------------------ | ------------------ |
| 任务 | 推荐 |
| # 用户 | 7068 (1图) |
| # 物品 | 3570 |
| # 交互 | 79506 |
| # 训练数据 | 60818 |
| # 测试数据 | 18688 |
| # 密度 | 0.315% |
- 数据准备
- 将数据集放到任意路径文件夹应该包含如下文件以Amazon-Beauty数据集为例
```text
.
└─data
├─ratings_Beauty.csv
```
- 为Amazon-Beauty生成MindRecord格式的数据集
```builddoutcfg
cd ./scripts
# SRC_PATH是您下载的数据集文件路径
sh run_process_data_ascend.sh [SRC_PATH]
```
- 启动
```text
# 为Amazon-Beauty生成MindRecord格式的数据集
sh ./run_process_data_ascend.sh ./data
```
# 特性
## 混合精度
@ -81,7 +88,7 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
# 环境要求
- 硬件Ascend
- 硬件Ascend/GPU
- 框架
- [MindSpore](https://www.mindspore.cn/install)
- 如需查看详情,请参见如下资源:
@ -104,6 +111,18 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
```
- GPU处理器环境运行
```text
# 使用Amazon-Beauty数据集运行训练示例
sh run_train_gpu.sh 0 dataset_path
# 使用Amazon-Beauty数据集运行评估示例
sh run_eval_gpu.sh 0 dataset_path
```
# 脚本说明
## 脚本及样例代码
@ -113,9 +132,11 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
└─bgcf
├─README.md
├─scripts
| ├─run_eval_ascend.sh # 启动评估
| ├─run_eval_ascend.sh # Ascend启动评估
| ├─run_eval_gpu.sh # GPU启动评估
| ├─run_process_data_ascend.sh # 生成MindRecord格式的数据集
| └─run_train_ascend.sh # 启动训练
| └─run_train_ascend.sh # Ascend启动训练
| └─run_train_gpu.sh # GPU启动训练
|
├─src
| ├─bgcf.py # BGCF模型
@ -178,7 +199,25 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
Epoch 598 iter 12 loss 3640.7612
Epoch 599 iter 12 loss 3654.9087
Epoch 600 iter 12 loss 3632.4585
...
```
- GPU处理器环境运行
```python
sh run_train_gpu.sh 0 dataset_path
```
训练结果将保存在脚本路径下文件夹名称以“train”开头。您可在日志中找到结果如下所示。
```python
Epoch 001 iter 12 loss 34696.242
Epoch 002 iter 12 loss 34275.508
Epoch 003 iter 12 loss 30620.635
Epoch 004 iter 12 loss 21628.908
```
@ -212,7 +251,23 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
sedp_@10:0.01896, sedp_@20:0.01504, nov_@10:7.57995, nov_@20:7.79439
epoch:600, recall_@10:0.09926, recall_@20:0.15080, ndcg_@10:0.07283, ndcg_@20:0.09016,
sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038
...
```
- GPU评估
```python
sh run_eval_gpu.sh 0 dataset_path
```
评估结果将保存在脚本路径下文件夹名称以“eval”开头。您可在日志中找到结果如下所示。
```python
epoch:680, recall_@10:0.10383, recall_@20:0.15524, ndcg_@10:0.07503, ndcg_@20:0.09249,
sedp_@10:0.01926, sedp_@20:0.01547, nov_@10:7.60851, nov_@20:7.81969
```
@ -220,19 +275,19 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
## 性能
| 参数 | BGCF |
| ------------------------------------ | ----------------------------------------- |
| 资源 | Ascend 910 |
| 上传日期 | 09/23/2020(月/日/年) |
| MindSpore版本 | 1.0.0 |
| 数据集 | Amazon-Beauty |
| 训练参数 | epoch=600 |
| 优化器 | Adam |
| 损失函数 | BPR loss |
| Recall@20 | 0.1534 |
| NDCG@20 | 0.0912 |
| 训练成本 | 25min |
| 脚本 | [bgcf脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) |
| 参数 | BGCF Ascend | BGCF GPU |
| -------------------------- | ------------------------------------------ | ------------------------------------------ |
| 资源 | Ascend 910 | Tesla V100-PCIE |
| 上传日期 | 09/23/2020(月/日/年) | 01/28/2021(月/日/年) |
| MindSpore版本 | 1.0.0 | Master(4b3e53b4) |
| 数据集 | Amazon-Beauty | Amazon-Beauty |
| 训练参数 | epoch=600,steps=12,batch_size=5000,lr=0.001| epoch=680,steps=12,batch_size=5000,lr=0.001|
| 优化器 | Adam | Adam |
| 损失函数 | BPR loss | BPR loss |
| Recall@20 | 0.1534 | 0.15524 |
| NDCG@20 | 0.0912 | 0.09249 |
| 训练成本 | 25min | 60min |
| 脚本 | [bgcf脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) | [bgcf脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) |
# 随机情况说明

@ -19,6 +19,7 @@ import datetime
import mindspore.context as context
from mindspore.train.serialization import load_checkpoint
from mindspore.common import set_seed
from src.bgcf import BGCF
from src.utils import BGCFLogger
@ -27,6 +28,7 @@ from src.metrics import BGCFEvaluate
from src.callback import ForwardBGCF, TestBGCF
from src.dataset import TestGraphDataset, load_graph
set_seed(1)
def evaluation():
"""evaluation"""
@ -34,7 +36,8 @@ def evaluation():
num_item = train_graph.graph_info()["node_num"][1]
eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)
for _epoch in range(parser.eval_interval, parser.num_epoch+1, parser.eval_interval):
for _epoch in range(parser.eval_interval, parser.num_epoch+1, parser.eval_interval) \
if parser.device_target == "Ascend" else range(parser.num_epoch, parser.num_epoch+1):
bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
parser.embedded_dimension,
parser.activation,
@ -79,9 +82,10 @@ def evaluation():
if __name__ == "__main__":
parser = parser_args()
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=int(parser.device))
device_target=parser.device_target,
save_graphs=False)
if parser.device_target == "Ascend":
context.set_context(device_id=int(parser.device))
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,

@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
if [ $# -lt 2 ]
then
echo "Usage: sh run_eval_gpu.sh [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi
export DEVICE_NUM=1
DATASET_PATH=$2
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation"
export CUDA_VISIBLE_DEVICES="$1"
python eval.py --datapath=$DATASET_PATH --ckptpath=../ckpts \
--device_target='GPU' --num_epoch=680 \
--dist_reg=0 > log 2>&1 &
cd ..

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: sh run_train_gpu.sh [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi
export DEVICE_NUM=1
DATASET_PATH=$2
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
if [ -d "ckpts" ];
then
rm -rf ./ckpts
fi
mkdir ./ckpts
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
env > env.log
echo "start training"
export CUDA_VISIBLE_DEVICES="$1"
python train.py --datapath=$DATASET_PATH --ckptpath=../ckpts \
--device_target='GPU' --num_epoch=680 \
--dist_reg=0 > log 2>&1 &
cd ..

@ -49,4 +49,5 @@ def parser_args():
parser.add_argument("-emb", "--embedded_dimension", type=int, default=64, help="output embedding dim")
parser.add_argument('--dist_reg', type=float, default=0.003, help="distance loss coefficient")
parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target')
return parser.parse_args()

@ -21,6 +21,7 @@ from mindspore import Tensor
import mindspore.context as context
from mindspore.common import dtype as mstype
from mindspore.train.serialization import save_checkpoint
from mindspore.common import set_seed
from src.bgcf import BGCF
from src.config import parser_args
@ -28,6 +29,7 @@ from src.utils import convert_item_id
from src.callback import TrainBGCF
from src.dataset import load_graph, create_dataset
set_seed(1)
def train():
"""Train"""
@ -102,10 +104,13 @@ def train():
if __name__ == "__main__":
parser = parser_args()
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=int(parser.device))
device_target=parser.device_target,
save_graphs=False)
if parser.device_target == "Ascend":
context.set_context(device_id=int(parser.device))
train_graph, _, sampled_graph_list = load_graph(parser.datapath)
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,

Loading…
Cancel
Save