add bgcf gpu

pull/11728/head
牛黄解毒片 5 years ago committed by panfengfeng
parent 54b8d53780
commit 1bff5f043d

@ -18,7 +18,9 @@
- [Performance](#performance) - [Performance](#performance)
- [Description of random situation](#description-of-random-situation) - [Description of random situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage) - [ModelZoo Homepage](#modelzoo-homepage)
<!--TOC --> <!--TOC -->
# [Bayesian Graph Collaborative Filtering](#contents) # [Bayesian Graph Collaborative Filtering](#contents)
Bayesian Graph Collaborative Filtering(BGCF) was proposed in 2020 by Sun J, Guo W, Zhang D et al. By naturally incorporating the Bayesian Graph Collaborative Filtering(BGCF) was proposed in 2020 by Sun J, Guo W, Zhang D et al. By naturally incorporating the
@ -33,12 +35,14 @@ Specially, BGCF contains two main modules. The first is sampling, which produce
aggregate the neighbors sampling from nodes consisting of mean aggregator and attention aggregator. aggregate the neighbors sampling from nodes consisting of mean aggregator and attention aggregator.
# [Dataset](#contents) # [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below. Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
- Dataset size: - Dataset size:
Statistics of dataset used are summarized as below: Statistics of dataset used are summarized as below:
| | Amazon-Beauty | | | Amazon-Beauty |
| ------------------ | -----------------------:| | ------------------ | ----------------------|
| Task | Recommendation | | Task | Recommendation |
| # User | 7068 (1 graph) | | # User | 7068 (1 graph) |
| # Item | 3570 | | # Item | 3570 |
@ -49,20 +53,22 @@ Note that you can run the scripts based on the dataset mentioned in original pap
- Data Preparation - Data Preparation
- Place the dataset to any path you want, the folder should include files as follows(we use Amazon-Beauty dataset as an example)" - Place the dataset to any path you want, the folder should include files as follows(we use Amazon-Beauty dataset as an example)"
```
```python
. .
└─data └─data
├─ratings_Beauty.csv ├─ratings_Beauty.csv
``` ```
- Generate dataset in mindrecord format for Amazon-Beauty. - Generate dataset in mindrecord format for Amazon-Beauty.
```builddoutcfg ```builddoutcfg
cd ./scripts cd ./scripts
# SRC_PATH is the dataset file path you download. # SRC_PATH is the dataset file path you download.
sh run_process_data_ascend.sh [SRC_PATH] sh run_process_data_ascend.sh [SRC_PATH]
``` ```
# [Features](#contents) # [Features](#contents)
## Mixed Precision ## Mixed Precision
@ -71,7 +77,7 @@ To ultilize the strong computation power of Ascend chip, and accelerate the trai
# [Environment Requirements](#contents) # [Environment Requirements](#contents)
- Hardward (Ascend) - Hardware (Ascend/GPU)
- Framework - Framework
- [MindSpore](https://www.mindspore.cn/install/en) - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below: - For more information, please check the resources below:
@ -84,7 +90,7 @@ After installing MindSpore via the official website and Dataset is correctly gen
- running on Ascend - running on Ascend
``` ```python
# run training example with Amazon-Beauty dataset # run training example with Amazon-Beauty dataset
sh run_train_ascend.sh sh run_train_ascend.sh
@ -92,6 +98,16 @@ After installing MindSpore via the official website and Dataset is correctly gen
sh run_eval_ascend.sh sh run_eval_ascend.sh
``` ```
- running on GPU
```python
# run training example with Amazon-Beauty dataset
sh run_train_gpu.sh 0 dataset_path
# run evaluation example with Amazon-Beauty dataset
sh run_eval_gpu.sh 0 dataset_path
```
# [Script Description](#contents) # [Script Description](#contents)
## [Script and Sample Code](#contents) ## [Script and Sample Code](#contents)
@ -101,9 +117,11 @@ After installing MindSpore via the official website and Dataset is correctly gen
└─bgcf └─bgcf
├─README.md ├─README.md
├─scripts ├─scripts
| ├─run_eval_ascend.sh # Launch evaluation | ├─run_eval_ascend.sh # Launch evaluation in ascend
| ├─run_eval_gpu.sh # Launch evaluation in gpu
| ├─run_process_data_ascend.sh # Generate dataset in mindrecord format | ├─run_process_data_ascend.sh # Generate dataset in mindrecord format
| └─run_train_ascend.sh # Launch training | └─run_train_ascend.sh # Launch training in ascend
| └─run_train_gpu.sh # Launch training in gpu
| |
├─src ├─src
| ├─bgcf.py # BGCF model | ├─bgcf.py # BGCF model
@ -131,8 +149,9 @@ Parameters for both training and evaluation can be set in config.py.
"gnew_neighs": 20, # Num of sampling neighbors in sample graph "gnew_neighs": 20, # Num of sampling neighbors in sample graph
"input_dim": 64, # User and item embedding dimension "input_dim": 64, # User and item embedding dimension
"l2": 0.03 # l2 coefficient "l2": 0.03 # l2 coefficient
"neighbor_dropout": [0.0, 0.2, 0.3]# Dropout ratio for different aggregation layer "neighbor_dropout": [0.0, 0.2, 0.3] # Dropout ratio for different aggregation layer
``` ```
config.py for more configuration. config.py for more configuration.
## [Training Process](#contents) ## [Training Process](#contents)
@ -140,6 +159,7 @@ Parameters for both training and evaluation can be set in config.py.
### Training ### Training
- running on Ascend - running on Ascend
```python ```python
sh run_train_ascend.sh sh run_train_ascend.sh
``` ```
@ -161,11 +181,28 @@ Parameters for both training and evaluation can be set in config.py.
... ...
``` ```
- running on GPU
```python
sh run_train_gpu.sh 0 dataset_path
```
Training result will be stored in the scripts path, whose folder name begins with "train". You can find the result like the
followings in log.
```python
Epoch 001 iter 12 loss 34696.242
Epoch 002 iter 12 loss 34275.508
Epoch 003 iter 12 loss 30620.635
Epoch 004 iter 12 loss 21628.908
```
## [Evaluation Process](#contents) ## [Evaluation Process](#contents)
### Evaluation ### Evaluation
- Evaluation on Ascend - Evaluation on Ascend
```python ```python
sh run_eval_ascend.sh sh run_eval_ascend.sh
``` ```
@ -190,34 +227,54 @@ Parameters for both training and evaluation can be set in config.py.
sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038 sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038
... ...
``` ```
- Evaluation on GPU
```python
sh run_eval_gpu.sh 0 dataset_path
```
Evaluation result will be stored in the scripts path, whose folder name begins with "eval". You can find the result like the
followings in log.
```python
epoch:680, recall_@10:0.10383, recall_@20:0.15524, ndcg_@10:0.07503, ndcg_@20:0.09249,
sedp_@10:0.01926, sedp_@20:0.01547, nov_@10:7.60851, nov_@20:7.81969
```
# [Model Description](#contents) # [Model Description](#contents)
## [Performance](#contents) ## [Performance](#contents)
### Evaluation Performance
| Parameter | BGCF | ### Training Performance
| ------------------------------------ | ----------------------------------------- |
| Model Version | Inception V1 | | Parameter | BGCF Ascend | BGCF GPU |
| Resource | Ascend 910 | | ------------------------------ | ------------------------------------------ | ------------------------------------------ |
| uploaded Date | 09/23/2020(month/day/year) | | Model Version | Inception V1 | Inception V1 |
| MindSpore Version | 1.0.0 | | Resource | Ascend 910 | Tesla V100-PCIE |
| Dataset | Amazon-Beauty | | uploaded Date | 09/23/2020(month/day/year) | 01/27/2021(month/day/year) |
| Training Parameter | epoch=600,steps=12,batch_size=5000,lr=0.001 | | MindSpore Version | 1.0.0 | 1.1.0 |
| Optimizer | Adam | | Dataset | Amazon-Beauty | Amazon-Beauty |
| Loss Function | BPR loss | | Training Parameter | epoch=600,steps=12,batch_size=5000,lr=0.001| epoch=680,steps=12,batch_size=5000,lr=0.001|
| Training Cost | 25min | | Optimizer | Adam | Adam |
| Scripts | [bgcf script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) | | Loss Function | BPR loss | BPR loss |
| Training Cost | 25min | 60min |
| Scripts | [bgcf script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) | [bgcf script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) |
### Inference Performance ### Inference Performance
| Parameter | BGCF |
| ------------------------------------ | ----------------------------------------- | | Parameter | BGCF Ascend | BGCF GPU |
| Model Version | Inception V1 | | ------------------------------ | ---------------------------- | ---------------------------- |
| Resource | Ascend 910 | | Model Version | Inception V1 | Inception V1 |
| uploaded Date | 09/23/2020(month/day/year) | | Resource | Ascend 910 | Tesla V100-PCIE |
| MindSpore Version | 1.0.0 | | uploaded Date | 09/23/2020(month/day/year) | 01/28/2021(month/day/year) |
| Dataset | Amazon-Beauty | | MindSpore Version | 1.0.0 | Master(4b3e53b4) |
| Batch_size | 5000 | | Dataset | Amazon-Beauty | Amazon-Beauty |
| Output | probability | | Batch_size | 5000 | 5000 |
| Recall@20 | 0.1534 | | Output | probability | probability |
| NDCG@20 | 0.0912 | | Recall@20 | 0.1534 | 0.15524 |
| NDCG@20 | 0.0912 | 0.09249 |
# [Description of random situation](#contents) # [Description of random situation](#contents)
BGCF model contains lots of dropout operations, if you want to disable dropout, set the neighbor_dropout to [0.0, 0.0, 0.0] in src/config.py. BGCF model contains lots of dropout operations, if you want to disable dropout, set the neighbor_dropout to [0.0, 0.0, 0.0] in src/config.py.
@ -225,5 +282,3 @@ BGCF model contains lots of dropout operations, if you want to disable dropout,
# [ModelZoo Homepage](#contents) # [ModelZoo Homepage](#contents)
Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo). Please check the official [homepage](http://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -41,7 +41,7 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
所用数据集的统计信息摘要如下: 所用数据集的统计信息摘要如下:
| | Amazon-Beauty | | | Amazon-Beauty |
| ------------------ | -----------------------:| | ------------------ | ------------------ |
| 任务 | 推荐 | | 任务 | 推荐 |
| # 用户 | 7068 (1图) | | # 用户 | 7068 (1图) |
| # 物品 | 3570 | | # 物品 | 3570 |
@ -54,25 +54,32 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
- 将数据集放到任意路径文件夹应该包含如下文件以Amazon-Beauty数据集为例 - 将数据集放到任意路径文件夹应该包含如下文件以Amazon-Beauty数据集为例
```text ```text
. .
└─data └─data
├─ratings_Beauty.csv ├─ratings_Beauty.csv
``` ```
- 为Amazon-Beauty生成MindRecord格式的数据集 - 为Amazon-Beauty生成MindRecord格式的数据集
```builddoutcfg ```builddoutcfg
cd ./scripts cd ./scripts
# SRC_PATH是您下载的数据集文件路径 # SRC_PATH是您下载的数据集文件路径
sh run_process_data_ascend.sh [SRC_PATH] sh run_process_data_ascend.sh [SRC_PATH]
``` ```
- 启动 - 启动
```text ```text
# 为Amazon-Beauty生成MindRecord格式的数据集 # 为Amazon-Beauty生成MindRecord格式的数据集
sh ./run_process_data_ascend.sh ./data sh ./run_process_data_ascend.sh ./data
```
# 特性 # 特性
## 混合精度 ## 混合精度
@ -81,7 +88,7 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
# 环境要求 # 环境要求
- 硬件Ascend - 硬件Ascend/GPU
- 框架 - 框架
- [MindSpore](https://www.mindspore.cn/install) - [MindSpore](https://www.mindspore.cn/install)
- 如需查看详情,请参见如下资源: - 如需查看详情,请参见如下资源:
@ -104,6 +111,18 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
``` ```
- GPU处理器环境运行
```text
# 使用Amazon-Beauty数据集运行训练示例
sh run_train_gpu.sh 0 dataset_path
# 使用Amazon-Beauty数据集运行评估示例
sh run_eval_gpu.sh 0 dataset_path
```
# 脚本说明 # 脚本说明
## 脚本及样例代码 ## 脚本及样例代码
@ -113,9 +132,11 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
└─bgcf └─bgcf
├─README.md ├─README.md
├─scripts ├─scripts
| ├─run_eval_ascend.sh # 启动评估 | ├─run_eval_ascend.sh # Ascend启动评估
| ├─run_eval_gpu.sh # GPU启动评估
| ├─run_process_data_ascend.sh # 生成MindRecord格式的数据集 | ├─run_process_data_ascend.sh # 生成MindRecord格式的数据集
| └─run_train_ascend.sh # 启动训练 | └─run_train_ascend.sh # Ascend启动训练
| └─run_train_gpu.sh # GPU启动训练
| |
├─src ├─src
| ├─bgcf.py # BGCF模型 | ├─bgcf.py # BGCF模型
@ -178,7 +199,25 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
Epoch 598 iter 12 loss 3640.7612 Epoch 598 iter 12 loss 3640.7612
Epoch 599 iter 12 loss 3654.9087 Epoch 599 iter 12 loss 3654.9087
Epoch 600 iter 12 loss 3632.4585 Epoch 600 iter 12 loss 3632.4585
...
```
- GPU处理器环境运行
```python
sh run_train_gpu.sh 0 dataset_path
```
训练结果将保存在脚本路径下文件夹名称以“train”开头。您可在日志中找到结果如下所示。
```python
Epoch 001 iter 12 loss 34696.242
Epoch 002 iter 12 loss 34275.508
Epoch 003 iter 12 loss 30620.635
Epoch 004 iter 12 loss 21628.908
``` ```
@ -212,7 +251,23 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
sedp_@10:0.01896, sedp_@20:0.01504, nov_@10:7.57995, nov_@20:7.79439 sedp_@10:0.01896, sedp_@20:0.01504, nov_@10:7.57995, nov_@20:7.79439
epoch:600, recall_@10:0.09926, recall_@20:0.15080, ndcg_@10:0.07283, ndcg_@20:0.09016, epoch:600, recall_@10:0.09926, recall_@20:0.15080, ndcg_@10:0.07283, ndcg_@20:0.09016,
sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038 sedp_@10:0.01890, sedp_@20:0.01517, nov_@10:7.58277, nov_@20:7.80038
...
```
- GPU评估
```python
sh run_eval_gpu.sh 0 dataset_path
```
评估结果将保存在脚本路径下文件夹名称以“eval”开头。您可在日志中找到结果如下所示。
```python
epoch:680, recall_@10:0.10383, recall_@20:0.15524, ndcg_@10:0.07503, ndcg_@20:0.09249,
sedp_@10:0.01926, sedp_@20:0.01547, nov_@10:7.60851, nov_@20:7.81969
``` ```
@ -220,19 +275,19 @@ BGCF包含两个主要模块。首先是抽样它生成基于节点复制的
## 性能 ## 性能
| 参数 | BGCF | | 参数 | BGCF Ascend | BGCF GPU |
| ------------------------------------ | ----------------------------------------- | | -------------------------- | ------------------------------------------ | ------------------------------------------ |
| 资源 | Ascend 910 | | 资源 | Ascend 910 | Tesla V100-PCIE |
| 上传日期 | 09/23/2020(月/日/年) | | 上传日期 | 09/23/2020(月/日/年) | 01/28/2021(月/日/年) |
| MindSpore版本 | 1.0.0 | | MindSpore版本 | 1.0.0 | Master(4b3e53b4) |
| 数据集 | Amazon-Beauty | | 数据集 | Amazon-Beauty | Amazon-Beauty |
| 训练参数 | epoch=600 | | 训练参数 | epoch=600,steps=12,batch_size=5000,lr=0.001| epoch=680,steps=12,batch_size=5000,lr=0.001|
| 优化器 | Adam | | 优化器 | Adam | Adam |
| 损失函数 | BPR loss | | 损失函数 | BPR loss | BPR loss |
| Recall@20 | 0.1534 | | Recall@20 | 0.1534 | 0.15524 |
| NDCG@20 | 0.0912 | | NDCG@20 | 0.0912 | 0.09249 |
| 训练成本 | 25min | | 训练成本 | 25min | 60min |
| 脚本 | [bgcf脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) | | 脚本 | [bgcf脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) | [bgcf脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/bgcf) |
# 随机情况说明 # 随机情况说明

@ -19,6 +19,7 @@ import datetime
import mindspore.context as context import mindspore.context as context
from mindspore.train.serialization import load_checkpoint from mindspore.train.serialization import load_checkpoint
from mindspore.common import set_seed
from src.bgcf import BGCF from src.bgcf import BGCF
from src.utils import BGCFLogger from src.utils import BGCFLogger
@ -27,6 +28,7 @@ from src.metrics import BGCFEvaluate
from src.callback import ForwardBGCF, TestBGCF from src.callback import ForwardBGCF, TestBGCF
from src.dataset import TestGraphDataset, load_graph from src.dataset import TestGraphDataset, load_graph
set_seed(1)
def evaluation(): def evaluation():
"""evaluation""" """evaluation"""
@ -34,7 +36,8 @@ def evaluation():
num_item = train_graph.graph_info()["node_num"][1] num_item = train_graph.graph_info()["node_num"][1]
eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks) eval_class = BGCFEvaluate(parser, train_graph, test_graph, parser.Ks)
for _epoch in range(parser.eval_interval, parser.num_epoch+1, parser.eval_interval): for _epoch in range(parser.eval_interval, parser.num_epoch+1, parser.eval_interval) \
if parser.device_target == "Ascend" else range(parser.num_epoch, parser.num_epoch+1):
bgcfnet_test = BGCF([parser.input_dim, num_user, num_item], bgcfnet_test = BGCF([parser.input_dim, num_user, num_item],
parser.embedded_dimension, parser.embedded_dimension,
parser.activation, parser.activation,
@ -79,9 +82,10 @@ def evaluation():
if __name__ == "__main__": if __name__ == "__main__":
parser = parser_args() parser = parser_args()
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", device_target=parser.device_target,
save_graphs=False, save_graphs=False)
device_id=int(parser.device)) if parser.device_target == "Ascend":
context.set_context(device_id=int(parser.device))
train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath) train_graph, test_graph, sampled_graph_list = load_graph(parser.datapath)
test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs, test_graph_dataset = TestGraphDataset(train_graph, sampled_graph_list, num_samples=parser.raw_neighs,

@ -0,0 +1,47 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
ulimit -u unlimited
if [ $# -lt 2 ]
then
echo "Usage: sh run_eval_gpu.sh [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi
export DEVICE_NUM=1
DATASET_PATH=$2
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start evaluation"
export CUDA_VISIBLE_DEVICES="$1"
python eval.py --datapath=$DATASET_PATH --ckptpath=../ckpts \
--device_target='GPU' --num_epoch=680 \
--dist_reg=0 > log 2>&1 &
cd ..

@ -0,0 +1,51 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -lt 2 ]
then
echo "Usage: sh run_train_gpu.sh [VISIABLE_DEVICES(0,1,2,3,4,5,6,7)] [DATASET_PATH]"
exit 1
fi
export DEVICE_NUM=1
DATASET_PATH=$2
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
if [ -d "ckpts" ];
then
rm -rf ./ckpts
fi
mkdir ./ckpts
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
env > env.log
echo "start training"
export CUDA_VISIBLE_DEVICES="$1"
python train.py --datapath=$DATASET_PATH --ckptpath=../ckpts \
--device_target='GPU' --num_epoch=680 \
--dist_reg=0 > log 2>&1 &
cd ..

@ -49,4 +49,5 @@ def parser_args():
parser.add_argument("-emb", "--embedded_dimension", type=int, default=64, help="output embedding dim") parser.add_argument("-emb", "--embedded_dimension", type=int, default=64, help="output embedding dim")
parser.add_argument('--dist_reg', type=float, default=0.003, help="distance loss coefficient") parser.add_argument('--dist_reg', type=float, default=0.003, help="distance loss coefficient")
parser.add_argument('--device_target', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='device target')
return parser.parse_args() return parser.parse_args()

@ -21,6 +21,7 @@ from mindspore import Tensor
import mindspore.context as context import mindspore.context as context
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.train.serialization import save_checkpoint from mindspore.train.serialization import save_checkpoint
from mindspore.common import set_seed
from src.bgcf import BGCF from src.bgcf import BGCF
from src.config import parser_args from src.config import parser_args
@ -28,6 +29,7 @@ from src.utils import convert_item_id
from src.callback import TrainBGCF from src.callback import TrainBGCF
from src.dataset import load_graph, create_dataset from src.dataset import load_graph, create_dataset
set_seed(1)
def train(): def train():
"""Train""" """Train"""
@ -102,10 +104,13 @@ def train():
if __name__ == "__main__": if __name__ == "__main__":
parser = parser_args() parser = parser_args()
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", device_target=parser.device_target,
save_graphs=False, save_graphs=False)
device_id=int(parser.device))
if parser.device_target == "Ascend":
context.set_context(device_id=int(parser.device))
train_graph, _, sampled_graph_list = load_graph(parser.datapath) train_graph, _, sampled_graph_list = load_graph(parser.datapath)
train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs, train_ds = create_dataset(train_graph, sampled_graph_list, parser.workers, batch_size=parser.batch_pairs,

Loading…
Cancel
Save