You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/official/gnn/gat/README_CN.md

203 lines
7.3 KiB

# 目录
<!-- TOC -->
- [目录](#目录)
- [图注意力网络描述](#图注意力网络描述)
- [模型架构](#模型架构)
- [数据集](#数据集)
- [特性](#特性)
- [混合精度](#混合精度)
- [环境要求](#环境要求)
- [快速入门](#快速入门)
- [脚本说明](#脚本说明)
- [脚本及样例代码](#脚本及样例代码)
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [训练](#训练)
- [模型描述](#模型描述)
- [性能](#性能)
- [随机情况说明](#随机情况说明)
- [ModelZoo主页](#modelzoo主页)
<!-- /TOC -->
# 图注意力网络描述
图注意力网络GAT由Petar Veličković等人于2017年提出。GAT通过利用掩蔽自注意层来克服现有基于图的方法的缺点在Cora等传感数据集和PPI等感应数据集上都达到了最先进的性能。以下是用MindSpore的Cora数据集训练GAT的例子。
[论文](https://arxiv.org/abs/1710.10903): Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., & Bengio, Y. (2017).Graph attention networks. arXiv preprint arXiv:1710.10903.
# 模型架构
请注意节点更新函数是级联还是平均,取决于注意力层是否为网络输出层。
# 数据集
- 数据集大小:
所用数据集汇总如下:
| | Cora | Citeseer |
| ------------------ | -------------: | -------------: |
| 任务 | Transductive | Transductive |
| # 节点 | 2708 (1图) | 3327 (1图) |
| # 边 | 5429 | 4732 |
| # 特性/节点 | 1433 | 3703 |
| # 类 | 7 | 6 |
| # 训练节点 | 140 | 120 |
| # 验证节点 | 500 | 500 |
| # 测试节点 | 1000 | 1000 |
- 数据准备
- 将数据集放到任意路径文件夹应该包含如下文件以Cora数据集为例
```text
.
└─data
├─ind.cora.allx
├─ind.cora.ally
├─ind.cora.graph
├─ind.cora.test.index
├─ind.cora.tx
├─ind.cora.ty
├─ind.cora.x
└─ind.cora.y
```
- 为Cora或Citeseer生成MindRecord格式的数据集
```buildoutcfg
cd ./scripts
# SRC_PATH为下载的数据集文件路径DATASET_NAME为Cora或Citeseer
sh run_process_data_ascend.sh [SRC_PATH] [DATASET_NAME]
```
- 启动
```text
# 为Cora生成MindRecord格式的数据集
./run_process_data_ascend.sh ./data cora
# 为Citeseer生成MindRecord格式的数据集
./run_process_data_ascend.sh ./data citeseer
```
# 特性
## 混合精度
为了充分利用Ascend芯片强大的运算能力加快训练过程此处采用混合训练方法。MindSpore能够处理FP32输入和FP16操作符。在GAT示例中除损失计算部分外模型设置为FP16模式。
# 环境要求
- 硬件Ascend
- 框架
- [MindSpore](https://www.mindspore.cn/install)
- 如需查看详情,请参见如下资源:
- [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# 快速入门
通过官方网站安装MindSpore并正确生成数据集后您可以按照如下步骤进行训练和评估
- Ascend处理器环境运行
```text
# 使用Cora数据集运行训练示例DATASET_NAME为cora
sh run_train_ascend.sh [DATASET_NAME]
```
# 脚本说明
## 脚本及样例代码
```shell
.
└─gat
├─README.md
├─scripts
| ├─run_process_data_ascend.sh # 生成MindRecord格式的数据集
| └─run_train_ascend.sh # 启动训练
|
├─src
| ├─config.py # 训练配置
| ├─dataset.py # 数据预处理
| ├─gat.py # GAT模型
| └─utils.py # 训练gat的工具
|
└─train.py # 训练网络
```
## 脚本参数
在config.py中可以同时配置训练参数和评估参数。
- 配置GAT和Cora数据集
```python
"learning_rate": 0.005, # 学习率
"num_epochs": 200, # 训练轮次
"hid_units": [8], # 每层注意头隐藏单元
"n_heads": [8, 1], # 每层头数
"early_stopping": 100, # 早停忍耐轮次数
"l2_coeff": 0.0005 # l2系数
"attn_dropout": 0.6 # 注意力层dropout系数
"feature_dropout":0.6 # 特征层dropout系数
```
## 训练过程
### 训练
- Ascend处理器环境运行
```python
sh run_train_ascend.sh [DATASET_NAME]
```
训练结果将保存在脚本路径下文件夹名称以“train”开头。您可在日志中找到结果
,如下所示。
```python
Epoch:0, train loss=1.98498 train acc=0.17143 | val loss=1.97946 val acc=0.27200
Epoch:1, train loss=1.98345 train acc=0.15000 | val loss=1.97233 val acc=0.32600
Epoch:2, train loss=1.96968 train acc=0.21429 | val loss=1.96747 val acc=0.37400
Epoch:3, train loss=1.97061 train acc=0.20714 | val loss=1.96410 val acc=0.47600
Epoch:4, train loss=1.96864 train acc=0.13571 | val loss=1.96066 val acc=0.59600
...
Epoch:195, train loss=1.45111 train_acc=0.56429 | val_loss=1.44325 val_acc=0.81200
Epoch:196, train loss=1.52476 train_acc=0.52143 | val_loss=1.43871 val_acc=0.81200
Epoch:197, train loss=1.35807 train_acc=0.62857 | val_loss=1.43364 val_acc=0.81400
Epoch:198, train loss=1.47566 train_acc=0.51429 | val_loss=1.42948 val_acc=0.81000
Epoch:199, train loss=1.56411 train_acc=0.55000 | val_loss=1.42632 val_acc=0.80600
Test loss=1.5366285, test acc=0.84199995
...
```
# 模型描述
## 性能
| 参数 | GAT |
| ------------------------------------ | ----------------------------------------- |
| 资源 | Ascend 910 |
| 上传日期 | 2020-06-16 |
| MindSpore版本 | 0.5.0-beta |
| 数据集 | Cora/Citeseer |
| 训练参数 | epoch=200 |
| 优化器 | Adam |
| 损失函数 | Softmax交叉熵 |
| 准确率 | 83.0/72.5 |
| 速度 | 0.195s/epoch |
| 总时长 | 39s |
| 脚本 | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/gnn/gat> |
# 随机情况说明
GAT模型中有很多的dropout操作如果想关闭dropout可以在src/config.py中将attn_dropout和feature_dropout设置为0。注该操作会导致准确率降低到80%左右。
# ModelZoo主页
请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。