From 9103ff5d753fa33808a9eb1b38f4a564467d2e86 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Dec 2020 15:55:29 +0800 Subject: [PATCH] lstm D network --- model_zoo/official/nlp/lstm/README.md | 181 ++++++---- model_zoo/official/nlp/lstm/README_CN.md | 326 ++++++++++++++++++ model_zoo/official/nlp/lstm/eval.py | 34 +- .../nlp/lstm/script/run_eval_ascend.sh | 39 +++ .../official/nlp/lstm/script/run_eval_cpu.sh | 2 +- .../official/nlp/lstm/script/run_eval_gpu.sh | 2 +- .../nlp/lstm/script/run_train_ascend.sh | 39 +++ .../official/nlp/lstm/script/run_train_cpu.sh | 2 +- .../official/nlp/lstm/script/run_train_gpu.sh | 2 +- model_zoo/official/nlp/lstm/src/config.py | 22 ++ .../official/nlp/lstm/src/lr_schedule.py | 60 ++++ model_zoo/official/nlp/lstm/src/lstm.py | 158 ++++++++- model_zoo/official/nlp/lstm/train.py | 34 +- 13 files changed, 823 insertions(+), 78 deletions(-) create mode 100644 model_zoo/official/nlp/lstm/README_CN.md create mode 100644 model_zoo/official/nlp/lstm/script/run_eval_ascend.sh create mode 100644 model_zoo/official/nlp/lstm/script/run_train_ascend.sh create mode 100644 model_zoo/official/nlp/lstm/src/lr_schedule.py diff --git a/model_zoo/official/nlp/lstm/README.md b/model_zoo/official/nlp/lstm/README.md index 7b2fe1c02d..bf1e807cec 100644 --- a/model_zoo/official/nlp/lstm/README.md +++ b/model_zoo/official/nlp/lstm/README.md @@ -1,3 +1,4 @@ +[查看中文](./README_CN.md) # Contents - [LSTM Description](#lstm-description) @@ -18,7 +19,6 @@ - [Description of Random Situation](#description-of-random-situation) - [ModelZoo Homepage](#modelzoo-homepage) - # [LSTM Description](#contents) This example is for LSTM model training and evaluation. @@ -29,26 +29,35 @@ This example is for LSTM model training and evaluation. LSTM contains embeding, encoder and decoder modules. Encoder module consists of LSTM layer. Decoder module consists of fully-connection layer. - # [Dataset](#contents) + Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below. - aclImdb_v1 for training evaluation.[Large Movie Review Dataset](http://ai.stanford.edu/~amaas/data/sentiment/) - GloVe: Vector representations for words.[GloVe: Global Vectors for Word Representation](https://nlp.stanford.edu/projects/glove/) - # [Environment Requirements](#contents) -- Hardware(GPU/CPU) +- Hardware(GPU/CPU/Ascend) + - If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you could get the resources for trial. - Framework - - [MindSpore](https://gitee.com/mindspore/mindspore) + - [MindSpore](https://gitee.com/mindspore/mindspore) - For more information, please check the resources below: - - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) - + - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) # [Quick Start](#contents) +- runing on Ascend + + ```bash + # run training example + bash run_train_ascend.sh 0 ./aclimdb ./glove_dir + + # run evaluation example + bash run_eval_ascend.sh 0 ./preprocess lstm-20_390.ckpt + ``` + - runing on GPU ```bash @@ -69,7 +78,6 @@ Note that you can run the scripts based on the dataset mentioned in original pap bash run_eval_cpu.sh ./aclimdb ./glove_dir lstm-20_390.ckpt ``` - # [Script Description](#contents) ## [Script and Sample Code](#contents) @@ -80,19 +88,21 @@ Note that you can run the scripts based on the dataset mentioned in original pap    ├── README.md # descriptions about LSTM    ├── script    │   ├── run_eval_gpu.sh # shell script for evaluation on GPU +    │   ├── run_eval_ascend.sh # shell script for evaluation on Ascend    │   ├── run_eval_cpu.sh # shell script for evaluation on CPU    │   ├── run_train_gpu.sh # shell script for training on GPU +    │   ├── run_train_ascend.sh # shell script for training on Ascend    │   └── run_train_cpu.sh # shell script for training on CPU    ├── src    │   ├── config.py # parameter configuration    │   ├── dataset.py # dataset preprocess    │   ├── imdb.py # imdb dataset read script +    │   ├── lr_schedule.py # dynamic_lr script    │   └── lstm.py # Sentiment model -    ├── eval.py # evaluation script on both GPU and CPU -    └── train.py # training script on both GPU and CPU +    ├── eval.py # evaluation script on GPU, CPU and Ascend +    └── train.py # training script on GPU, CPU and Ascend ``` - ## [Script Parameters](#contents) ### Training Script Parameters @@ -101,7 +111,7 @@ Note that you can run the scripts based on the dataset mentioned in original pap usage: train.py [-h] [--preprocess {true, false}] [--aclimdb_path ACLIMDB_PATH] [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH] [--ckpt_path CKPT_PATH] [--pre_trained PRE_TRAINING] - [--device_target {GPU, CPU}] + [--device_target {GPU, CPU, Ascend}] Mindspore LSTM Example @@ -113,15 +123,16 @@ options: --preprocess_path PREPROCESS_PATH # path where the pre-process data is stored. --ckpt_path CKPT_PATH # the path to save the checkpoint file. --pre_trained # the pretrained checkpoint file path. - --device_target # the target device to run, support "GPU", "CPU". Default: "GPU". + --device_target # the target device to run, support "GPU", "CPU", "Ascend". Default: "Ascend". ``` - ### Running Options ```python config.py: +GPU/CPU: num_classes # classes num + dynamic_lr # if use dynamic learning rate learning_rate # value of learning rate momentum # value of momentum num_epochs # epoch size @@ -131,42 +142,81 @@ config.py: num_layers # number of layers of stacked LSTM bidirectional # specifies whether it is a bidirectional LSTM save_checkpoint_steps # steps for saving checkpoint files + +Ascend: + num_classes # classes num + momentum # value of momentum + num_epochs # epoch size + batch_size # batch size of input dataset + embed_size # the size of each embedding vector + num_hiddens # number of features of hidden layer + num_layers # number of layers of stacked LSTM + bidirectional # specifies whether it is a bidirectional LSTM + save_checkpoint_steps # steps for saving checkpoint files + keep_checkpoint_max # max num of checkpoint files + dynamic_lr # if use dynamic learning rate + lr_init # init learning rate of Dynamic learning rate + lr_end # end learning rate of Dynamic learning rate + lr_max # max learning rate of Dynamic learning rate + lr_adjust_epoch # Dynamic learning rate adjust epoch + warmup_epochs # warmup epochs + global_step # global step ``` ### Network Parameters - ## [Dataset Preparation](#contents) + - Download the dataset aclImdb_v1. -> Unzip the aclImdb_v1 dataset to any path you want and the folder structure should be as follows: -> ``` -> . -> ├── train # train dataset -> └── test # infer dataset -> ``` + Unzip the aclImdb_v1 dataset to any path you want and the folder structure should be as follows: + + ```bash + . + ├── train # train dataset + └── test # infer dataset + ``` - Download the GloVe file. -> Unzip the glove.6B.zip to any path you want and the folder structure should be as follows: -> ``` -> . -> ├── glove.6B.100d.txt -> ├── glove.6B.200d.txt -> ├── glove.6B.300d.txt # we will use this one later. -> └── glove.6B.50d.txt -> ``` - -> Adding a new line at the beginning of the file which named `glove.6B.300d.txt`. -> It means reading a total of 400,000 words, each represented by a 300-latitude word vector. -> ``` -> 400000 300 -> ``` + Unzip the glove.6B.zip to any path you want and the folder structure should be as follows: + + ```bash + . + ├── glove.6B.100d.txt + ├── glove.6B.200d.txt + ├── glove.6B.300d.txt # we will use this one later. + └── glove.6B.50d.txt + ``` + + Adding a new line at the beginning of the file which named `glove.6B.300d.txt`. + It means reading a total of 400,000 words, each represented by a 300-latitude word vector. + + ```bash + 400000 300 + ``` ## [Training Process](#contents) - Set options in `config.py`, including learning rate and network hyperparameters. +- runing on Ascend + + Run `sh run_train_ascend.sh` for training. + + ``` bash + bash run_train_ascend.sh 0 ./aclimdb ./glove_dir + ``` + + The above shell script will train in the background. You will get the loss value as following: + + ```shell + # grep "loss is " log.txt + epoch: 1 step: 390, loss is 0.6003723 + epcoh: 2 step: 390, loss is 0.35312173 + ... + ``` + - runing on GPU Run `sh run_train_gpu.sh` for training. @@ -176,6 +226,7 @@ config.py: ``` The above shell script will run distribute training in the background. You will get the loss value as following: + ```shell # grep "loss is " log.txt epoch: 1 step: 390, loss is 0.6003723 @@ -200,9 +251,16 @@ config.py: ... ``` - ## [Evaluation Process](#contents) +- evaluation on Ascend + + Run `bash run_eval_ascend.sh` for evaluation. + + ``` bash + bash run_eval_ascend.sh 0 ./preprocess lstm-20_390.ckpt + ``` + - evaluation on GPU Run `bash run_eval_gpu.sh` for evaluation. @@ -220,45 +278,44 @@ config.py: ``` # [Model Description](#contents) + ## [Performance](#contents) ### Training Performance -| Parameters | LSTM (GPU) | LSTM (CPU) | -| -------------------------- | -------------------------------------------------------------- | -------------------------- | -| Resource | Tesla V100-SMX2-16GB | Ubuntu X86-i7-8565U-16GB | -| uploaded Date | 10/28/2020 (month/day/year) | 10/28/2020 (month/day/year)| -| MindSpore Version | 1.0.0 | 1.0.0 | -| Dataset | aclimdb_v1 | aclimdb_v1 | -| Training Parameters | epoch=20, batch_size=64 | epoch=20, batch_size=64 | -| Optimizer | Momentum | Momentum | -| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy | -| Speed | 1022 (1pcs) | 20 | -| Loss | 0.12 | 0.12 | -| Params (M) | 6.45 | 6.45 | -| Checkpoint for inference | 292.9M (.ckpt file) | 292.9M (.ckpt file) | -| Scripts | [lstm script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm) | [lstm script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm) | - +| Parameters | LSTM (Ascend) | LSTM (GPU) | LSTM (CPU) | +| -------------------------- | -------------------------- | -------------------------------------------------------------- | -------------------------- | +| Resource | Ascend 910 | Tesla V100-SMX2-16GB | Ubuntu X86-i7-8565U-16GB | +| uploaded Date | 12/21/2020 (month/day/year)| 10/28/2020 (month/day/year) | 10/28/2020 (month/day/year)| +| MindSpore Version | 1.0.0 | 1.0.0 | 1.0.0 | +| Dataset | aclimdb_v1 | aclimdb_v1 | aclimdb_v1 | +| Training Parameters | epoch=20, batch_size=64 | epoch=20, batch_size=64 | epoch=20, batch_size=64 | +| Optimizer | Momentum | Momentum | Momentum | +| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy | Softmax Cross Entropy | +| Speed | 1097 | 1022 (1pcs) | 20 | +| Loss | 0.12 | 0.12 | 0.12 | +| Params (M) | 6.45 | 6.45 | 6.45 | +| Checkpoint for inference | 292.9M (.ckpt file) | 292.9M (.ckpt file) | 292.9M (.ckpt file) | +| Scripts | [lstm script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm) | [lstm script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm) | [lstm script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm) | ### Evaluation Performance -| Parameters | LSTM (GPU) | LSTM (CPU) | -| ------------------- | --------------------------- | ---------------------------- | -| Resource | Tesla V100-SMX2-16GB | Ubuntu X86-i7-8565U-16GB | -| uploaded Date | 10/28/2020 (month/day/year) | 10/28/2020 (month/day/year) | -| MindSpore Version | 1.0.0 | 1.0.0 | -| Dataset | aclimdb_v1 | aclimdb_v1 | -| batch_size | 64 | 64 | -| Accuracy | 84% | 83% | - +| Parameters | LSTM (Ascend) | LSTM (GPU) | LSTM (CPU) | +| ------------------- | ---------------------------- | --------------------------- | ---------------------------- | +| Resource | Ascend 910 | Tesla V100-SMX2-16GB | Ubuntu X86-i7-8565U-16GB | +| uploaded Date | 12/21/2020 (month/day/year) | 10/28/2020 (month/day/year) | 10/28/2020 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.0.0 | 1.0.0 | +| Dataset | aclimdb_v1 | aclimdb_v1 | aclimdb_v1 | +| batch_size | 64 | 64 | 64 | +| Accuracy | 85% | 84% | 83% | # [Description of Random Situation](#contents) There are three random situations: + - Shuffle of the dataset. - Initialization of some model weights. - # [ModelZoo Homepage](#contents) Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo). diff --git a/model_zoo/official/nlp/lstm/README_CN.md b/model_zoo/official/nlp/lstm/README_CN.md new file mode 100644 index 0000000000..2768ed3a2d --- /dev/null +++ b/model_zoo/official/nlp/lstm/README_CN.md @@ -0,0 +1,326 @@ +[View English](./README.md) +# 目录 + + +- [目录](#目录) +- [LSTM概述](#lstm概述) +- [模型架构](#模型架构) +- [数据集](#数据集) +- [环境要求](#环境要求) +- [快速入门](#快速入门) +- [脚本说明](#脚本说明) + - [脚本和样例代码](#脚本和样例代码) + - [脚本参数](#脚本参数) + - [训练脚本参数](#训练脚本参数) + - [运行选项](#运行选项) + - [网络参数](#网络参数) + - [准备数据集](#准备数据集) + - [训练过程](#训练过程) + - [评估过程](#评估过程) +- [模型描述](#模型描述) + - [性能](#性能) + - [训练性能](#训练性能) + - [评估性能](#评估性能) +- [随机情况说明](#随机情况说明) +- [ModelZoo主页](#modelzoo主页) + + + +# LSTM概述 + +本示例用于LSTM模型训练和评估。 + +[论文](https://www.aclweb.org/anthology/P11-1015/): Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, Christopher Potts。[面向情绪分析学习词向量](https://www.aclweb.org/anthology/P11-1015/),Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies.2011 + +# 模型架构 + +LSTM模型包含嵌入层、编码器和解码器这几个模块,编码器模块由LSTM层组成,解码器模块由全连接层组成。 + +# 数据集 + +- aclImdb_v1用于训练评估。[大型电影评论数据集](http://ai.stanford.edu/~amaas/data/sentiment/) +- 单词表示形式的全局矢量(GloVe):用于单词的向量表示。[GloVe](https://nlp.stanford.edu/projects/glove/) + +# 环境要求 + +- 硬件(GPU/CPU/Ascend) + - 如果你想尝试Ascend,请发送[Ascend Model Zoo体验资源申请表](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx)到ascend@huawei.com申请Ascend体验资源。 +- 框架 + - [MindSpore](https://www.mindspore.cn/install) +- 更多关于Mindspore的信息,请查看以下资源: + - [MindSpore教程](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html) + +# 快速入门 + +- 在Ascend处理器上运行 + + ```bash + # 运行训练示例 + bash run_train_ascend.sh 0 ./aclimdb ./glove_dir + + # 运行评估示例 + bash run_eval_ascend.sh 0 ./preprocess lstm-20_390.ckpt + ``` + +- 在GPU处理器上运行 + + ```bash + # 运行训练示例 + bash run_train_gpu.sh 0 ./aclimdb ./glove_dir + + # 运行评估示例 + bash run_eval_gpu.sh 0 ./aclimdb ./glove_dir lstm-20_390.ckpt + ``` + +- 在CPU处理器上运行 + + ```bash + # 运行训练示例 + bash run_train_cpu.sh ./aclimdb ./glove_dir + + # 运行评估示例 + bash run_eval_cpu.sh ./aclimdb ./glove_dir lstm-20_390.ckpt + ``` + +# 脚本说明 + +## 脚本和样例代码 + +```shell +. +├── lstm +    ├── README.md # LSTM相关说明 +    ├── script +    │   ├── run_eval_ascend.sh # Ascend评估的shell脚本 +    │   ├── run_eval_gpu.sh # GPU评估的shell脚本 +    │   ├── run_eval_cpu.sh # CPU评估shell脚本 +    │   ├── run_train_ascend.sh # Ascend训练的shell脚本 +    │   ├── run_train_gpu.sh # GPU训练的shell脚本 +    │   └── run_train_cpu.sh # CPU训练的shell脚本 +    ├── src +    │   ├── config.py # 参数配置 +    │   ├── dataset.py # 数据集预处理 +    │   ├── imdb.py # IMDB数据集读脚本 +    │   ├── lr_schedule.py # 动态学习率脚步 +    │   └── lstm.py # 情感模型 +    ├── eval.py # GPU、CPU和Ascend的评估脚本 +    └── train.py # GPU、CPU和Ascend的训练脚本 +``` + +## 脚本参数 + +### 训练脚本参数 + +```python +用法:train.py [-h] [--preprocess {true, false}] [--aclimdb_path ACLIMDB_PATH] + [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH] + [--ckpt_path CKPT_PATH] [--pre_trained PRE_TRAINING] + [--device_target {GPU, CPU, Ascend}] + +Mindspore LSTM示例 + +选项: + -h, --help # 显示此帮助信息并退出 + --preprocess {true, false} # 是否进行数据预处理 + --aclimdb_path ACLIMDB_PATH # 数据集所在路径 + --glove_path GLOVE_PATH # GloVe工具所在路径 + --preprocess_path PREPROCESS_PATH # 预处理数据存放路径 + --ckpt_path CKPT_PATH # 检查点文件保存路径 + --pre_trained # 预训练的checkpoint文件路径 + --device_target # 待运行的目标设备,支持GPU、CPU、Ascend。默认值:"Ascend"。 +``` + +### 运行选项 + +```python +config.py: +GPU/CPU: + num_classes # 类别数 + dynamic_lr # 是否使用动态学习率 + learning_rate # 学习率 + momentum # 动量 + num_epochs # 轮次 + batch_size # 输入数据集的批次大小 + embed_size # 每个嵌入向量的大小 + num_hiddens # 隐藏层特征数 + num_layers # 栈式LSTM的层数 + bidirectional # 是否双向LSTM + save_checkpoint_steps # 保存检查点文件的步数 + +Ascend: + num_classes # 类别数 + momentum # 动量 + num_epochs # 轮次 + batch_size # 输入数据集的批次大小 + embed_size # 每个嵌入向量的大小 + num_hiddens # 隐藏层特征数 + num_layers # 栈式LSTM的层数 + bidirectional # 是否双向LSTM + save_checkpoint_steps # 保存检查点文件的步数 + keep_checkpoint_max # 最多保存ckpt个数 + dynamic_lr # 是否使用动态学习率 + lr_init # 动态学习率的起始学习率 + lr_end # 动态学习率的最终学习率 + lr_max # 动态学习率的最大学习率 + lr_adjust_epoch # 动态学习率在此epoch范围内调整 + warmup_epochs # warmup的epoch数 + global_step # 全局步数 +``` + +### 网络参数 + +## 准备数据集 + +- 下载aclImdb_v1数据集。 + + 将aclImdb_v1数据集解压到任意路径,文件夹结构如下: + + ```bash + . + ├── train # 训练数据集 + └── test # 推理数据集 + ``` + +- 下载GloVe文件。 + + 将glove.6B.zip解压到任意路径,文件夹结构如下: + + ```bash + . + ├── glove.6B.100d.txt + ├── glove.6B.200d.txt + ├── glove.6B.300d.txt # 后续会用到这个文件 + └── glove.6B.50d.txt + ``` + + 在`glove.6B.300d.txt`文件开头增加一行。 + 用来读取40万个单词,每个单词由300纬度的词向量来表示。 + + ```bash + 400000 300 + ``` + +## 训练过程 + +- 在`config.py`中设置选项,包括loss_scale、学习率和网络超参。 + +- 运行在Ascend处理器上 + + 执行`sh run_train_ascend.sh`进行训练。 + + ``` bash + bash run_train_ascend.sh 0 ./aclimdb ./glove_dir + ``` + + 上述shell脚本在后台执行训练,得到如下损失值: + + ```shell + # grep "loss is " log.txt + epoch: 1 step: 390, loss is 0.6003723 + epcoh: 2 step: 390, loss is 0.35312173 + ... + ``` + +- 在GPU处理器上运行 + + 执行`sh run_train_gpu.sh`进行训练。 + + ``` bash + bash run_train_gpu.sh 0 ./aclimdb ./glove_dir + ``` + + 上述shell脚本在后台运行分布式训练,得到如下损失值: + + ```shell + # grep "loss is " log.txt + epoch: 1 step: 390, loss is 0.6003723 + epcoh: 2 step: 390, loss is 0.35312173 + ... + ``` + +- 运行在CPU处理器上 + + 执行`sh run_train_cpu.sh`进行训练。 + + ``` bash + bash run_train_cpu.sh ./aclimdb ./glove_dir + ``` + + 上述shell脚本在后台执行训练,得到如下损失值: + + ```shell + # grep "loss is " log.txt + epoch: 1 step: 390, loss is 0.6003723 + epcoh: 2 step: 390, loss is 0.35312173 + ... + ``` + +## 评估过程 + +- 在Ascend处理器上进行评估 + + 执行`bash run_eval_ascend.sh`进行评估。 + + ``` bash + bash run_eval_ascend.sh 0 ./preprocess lstm-20_390.ckpt + ``` + +- 在GPU处理器上进行评估 + + 执行`bash run_eval_gpu.sh`进行评估。 + + ``` bash + bash run_eval_gpu.sh 0 ./aclimdb ./glove_dir lstm-20_390.ckpt + ``` + +- 在CPU处理器上进行评估 + + 执行`bash run_eval_cpu.sh`进行评估。 + + ``` bash + bash run_eval_cpu.sh 0 ./aclimdb ./glove_dir lstm-20_390.ckpt + ``` + +# 模型描述 + +## 性能 + +### 训练性能 + +| 参数 | LSTM (Ascend) | LSTM (GPU) | LSTM (CPU) | +| -------------------------- | -------------------------- | -------------------------------------------------------------- | -------------------------- | +| 资源 | Ascend 910 | Tesla V100-SMX2-16GB | Ubuntu X86-i7-8565U-16GB | +| 上传日期 | 2020-12-21 | 2020-08-06 | 2020-08-06 | +| MindSpore版本 | 1.0.0 | 0.6.0-beta | 0.6.0-beta | +| 数据集 | aclimdb_v1 | aclimdb_v1 | aclimdb_v1 | +| 训练参数 | epoch=20, batch_size=64 | epoch=20, batch_size=64 | epoch=20, batch_size=64 | +| 优化器 | Momentum | Momentum | Momentum | +| 损失函数 | SoftmaxCrossEntropy | SoftmaxCrossEntropy | SoftmaxCrossEntropy | +| 速度 | 1097 | 1022(单卡) | 20 | +| 损失 | 0.12 | 0.12 | 0.12 | +| 参数(M) | 6.45 | 6.45 | 6.45 | +| 推理检查点 | 292.9M(.ckpt文件) | 292.9M(.ckpt文件) | 292.9M(.ckpt文件) | +| 脚本 | [LSTM脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm) | [LSTM脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm) | [LSTM脚本](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/lstm) | + +### 评估性能 + +| 参数 | LSTM (Ascend) | LSTM (GPU) | LSTM (CPU) | +| ------------------- | ---------------------------- | --------------------------- | ---------------------------- | +| 资源 | Ascend 910 | Tesla V100-SMX2-16GB | Ubuntu X86-i7-8565U-16GB | +| 上传日期 | 2020-12-21 | 2020-08-06 | 2020-08-06 | +| MindSpore版本 | 1.0.0 | 0.6.0-beta | 0.6.0-beta | +| 数据集 | aclimdb_v1 | aclimdb_v1 | aclimdb_v1 | +| batch_size | 64 | 64 | 64 | +| 准确率 | 85% | 84% | 83% | + +# 随机情况说明 + +随机情况如下: + +- 轮换数据集。 +- 初始化部分模型权重。 + +# ModelZoo主页 + +请浏览官网[主页](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)。 diff --git a/model_zoo/official/nlp/lstm/eval.py b/model_zoo/official/nlp/lstm/eval.py index f9960ce55f..f7163fa0de 100644 --- a/model_zoo/official/nlp/lstm/eval.py +++ b/model_zoo/official/nlp/lstm/eval.py @@ -20,8 +20,9 @@ import os import numpy as np -from src.config import lstm_cfg as cfg +from src.config import lstm_cfg as cfg, lstm_cfg_ascend from src.dataset import lstm_create_dataset, convert_to_mindrecord +from src.lr_schedule import get_lr from src.lstm import SentimentNet from mindspore import Tensor, nn, Model, context from mindspore.nn import Accuracy @@ -40,8 +41,8 @@ if __name__ == '__main__': help='path where the pre-process data is stored.') parser.add_argument('--ckpt_path', type=str, default=None, help='the checkpoint file path used to evaluate model.') - parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'], - help='the target device to run, support "GPU", "CPU". Default: "GPU".') + parser.add_argument('--device_target', type=str, default="Ascend", choices=['GPU', 'CPU', 'Ascend'], + help='the target device to run, support "GPU", "CPU". Default: "Ascend".') args = parser.parse_args() context.set_context( @@ -49,11 +50,24 @@ if __name__ == '__main__': save_graphs=False, device_target=args.device_target) + if args.device_target == 'Ascend': + cfg = lstm_cfg_ascend + else: + cfg = lstm_cfg + if args.preprocess == "true": print("============== Starting Data Pre-processing ==============") convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path) embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32) + # DynamicRNN in this network on Ascend platform only support the condition that the shape of input_size + # and hiddle_size is multiples of 16, this problem will be solved later. + if args.device_target == 'Ascend': + pad_num = int(np.ceil(cfg.embed_size / 16) * 16 - cfg.embed_size) + if pad_num > 0: + embedding_table = np.pad(embedding_table, [(0, 0), (0, pad_num)], 'constant') + cfg.embed_size = int(np.ceil(cfg.embed_size / 16) * 16) + network = SentimentNet(vocab_size=embedding_table.shape[0], embed_size=cfg.embed_size, num_hiddens=cfg.num_hiddens, @@ -64,13 +78,23 @@ if __name__ == '__main__': batch_size=cfg.batch_size) loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum) + ds_eval = lstm_create_dataset(args.preprocess_path, cfg.batch_size, training=False) + if cfg.dynamic_lr: + lr = Tensor(get_lr(global_step=cfg.global_step, + lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max, + warmup_epochs=cfg.warmup_epochs, + total_epochs=cfg.num_epochs, + steps_per_epoch=ds_eval.get_dataset_size(), + lr_adjust_epoch=cfg.lr_adjust_epoch)) + else: + lr = cfg.learning_rate + + opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum) loss_cb = LossMonitor() model = Model(network, loss, opt, {'acc': Accuracy()}) print("============== Starting Testing ==============") - ds_eval = lstm_create_dataset(args.preprocess_path, cfg.batch_size, training=False) param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) if args.device_target == "CPU": diff --git a/model_zoo/official/nlp/lstm/script/run_eval_ascend.sh b/model_zoo/official/nlp/lstm/script/run_eval_ascend.sh new file mode 100644 index 0000000000..23f752cd16 --- /dev/null +++ b/model_zoo/official/nlp/lstm/script/run_eval_ascend.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_eval_ascend.sh DEVICE_ID PREPROCESS_DIR CKPT_FILE" +echo "for example: bash run_eval_ascend.sh 0 ./preprocess lstm-20_390.ckpt" +echo "==============================================================================================================" + +DEVICE_ID=$1 +PREPROCESS_DIR=$2 +CKPT_FILE=$3 + +rm -rf eval +mkdir -p eval +cd eval +mkdir -p ms_log +CUR_DIR=`pwd` +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +export DEVICE_ID=$DEVICE_ID +python ../../eval.py \ + --device_target="Ascend" \ + --preprocess=false \ + --preprocess_path=$PREPROCESS_DIR \ + --ckpt_path=$CKPT_FILE > log.txt 2>&1 & diff --git a/model_zoo/official/nlp/lstm/script/run_eval_cpu.sh b/model_zoo/official/nlp/lstm/script/run_eval_cpu.sh index e9740e1b90..2ffe41c2e3 100644 --- a/model_zoo/official/nlp/lstm/script/run_eval_cpu.sh +++ b/model_zoo/official/nlp/lstm/script/run_eval_cpu.sh @@ -15,7 +15,7 @@ # ============================================================================ echo "==============================================================================================================" -echo "Please run the scipt as: " +echo "Please run the script as: " echo "bash run_eval_cpu.sh ACLIMDB_DIR GLOVE_DIR CKPT_FILE" echo "for example: bash run_eval_cpu.sh ./aclimdb ./glove_dir lstm-20_390.ckpt" echo "==============================================================================================================" diff --git a/model_zoo/official/nlp/lstm/script/run_eval_gpu.sh b/model_zoo/official/nlp/lstm/script/run_eval_gpu.sh index 9fc99f53ca..e2fa176f0f 100644 --- a/model_zoo/official/nlp/lstm/script/run_eval_gpu.sh +++ b/model_zoo/official/nlp/lstm/script/run_eval_gpu.sh @@ -15,7 +15,7 @@ # ============================================================================ echo "==============================================================================================================" -echo "Please run the scipt as: " +echo "Please run the script as: " echo "bash run_train_gpu.sh DEVICE_ID ACLIMDB_DIR GLOVE_DIR CKPT_FILE" echo "for example: bash run_train_gpu.sh 0 ./aclimdb ./glove_dir lstm-20_390.ckpt" echo "==============================================================================================================" diff --git a/model_zoo/official/nlp/lstm/script/run_train_ascend.sh b/model_zoo/official/nlp/lstm/script/run_train_ascend.sh new file mode 100644 index 0000000000..9a52f24d93 --- /dev/null +++ b/model_zoo/official/nlp/lstm/script/run_train_ascend.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +echo "==============================================================================================================" +echo "Please run the script as: " +echo "bash run_train_ascend.sh DEVICE_ID ACLIMDB_DIR GLOVE_DIR" +echo "for example: bash run_train_ascend.sh 0 ./aclimdb ./glove_dir" +echo "==============================================================================================================" + +DEVICE_ID=$1 +ACLIMDB_DIR=$2 +GLOVE_DIR=$3 + +mkdir -p train +cd train +mkdir -p ms_log +CUR_DIR=`pwd` +export GLOG_log_dir=${CUR_DIR}/ms_log +export GLOG_logtostderr=0 +export DEVICE_ID=$DEVICE_ID +python ../../train.py \ + --device_target="Ascend" \ + --aclimdb_path=$ACLIMDB_DIR \ + --glove_path=$GLOVE_DIR \ + --preprocess=true \ + --preprocess_path=./preprocess > log.txt 2>&1 & diff --git a/model_zoo/official/nlp/lstm/script/run_train_cpu.sh b/model_zoo/official/nlp/lstm/script/run_train_cpu.sh index 26b3a422cd..6d871deb8c 100644 --- a/model_zoo/official/nlp/lstm/script/run_train_cpu.sh +++ b/model_zoo/official/nlp/lstm/script/run_train_cpu.sh @@ -15,7 +15,7 @@ # ============================================================================ echo "==============================================================================================================" -echo "Please run the scipt as: " +echo "Please run the script as: " echo "bash run_train_cpu.sh ACLIMDB_DIR GLOVE_DIR" echo "for example: bash run_train_gpu.sh ./aclimdb ./glove_dir" echo "==============================================================================================================" diff --git a/model_zoo/official/nlp/lstm/script/run_train_gpu.sh b/model_zoo/official/nlp/lstm/script/run_train_gpu.sh index 79f52580a1..df29234566 100644 --- a/model_zoo/official/nlp/lstm/script/run_train_gpu.sh +++ b/model_zoo/official/nlp/lstm/script/run_train_gpu.sh @@ -15,7 +15,7 @@ # ============================================================================ echo "==============================================================================================================" -echo "Please run the scipt as: " +echo "Please run the script as: " echo "bash run_train_gpu.sh DEVICE_ID ACLIMDB_DIR GLOVE_DIR" echo "for example: bash run_train_gpu.sh 0 ./aclimdb ./glove_dir" echo "==============================================================================================================" diff --git a/model_zoo/official/nlp/lstm/src/config.py b/model_zoo/official/nlp/lstm/src/config.py index 688760111c..741ab045e1 100644 --- a/model_zoo/official/nlp/lstm/src/config.py +++ b/model_zoo/official/nlp/lstm/src/config.py @@ -20,6 +20,7 @@ from easydict import EasyDict as edict # LSTM CONFIG lstm_cfg = edict({ 'num_classes': 2, + 'dynamic_lr': False, 'learning_rate': 0.1, 'momentum': 0.9, 'num_epochs': 20, @@ -31,3 +32,24 @@ lstm_cfg = edict({ 'save_checkpoint_steps': 390, 'keep_checkpoint_max': 10 }) + +# LSTM CONFIG IN ASCEND +lstm_cfg_ascend = edict({ + 'num_classes': 2, + 'momentum': 0.9, + 'num_epochs': 20, + 'batch_size': 64, + 'embed_size': 300, + 'num_hiddens': 128, + 'num_layers': 2, + 'bidirectional': True, + 'save_checkpoint_steps': 7800, + 'keep_checkpoint_max': 10, + 'dynamic_lr': True, + 'lr_init': 0.05, + 'lr_end': 0.01, + 'lr_max': 0.1, + 'lr_adjust_epoch': 6, + 'warmup_epochs': 1, + 'global_step': 0 +}) diff --git a/model_zoo/official/nlp/lstm/src/lr_schedule.py b/model_zoo/official/nlp/lstm/src/lr_schedule.py new file mode 100644 index 0000000000..131877859d --- /dev/null +++ b/model_zoo/official/nlp/lstm/src/lr_schedule.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Learning rate schedule""" + +import math +import numpy as np + + +def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_adjust_epoch): + """ + generate learning rate array + + Args: + global_step(int): total steps of the training + lr_init(float): init learning rate + lr_end(float): end learning rate + lr_max(float): max learning rate + warmup_epochs(float): number of warmup epochs + total_epochs(int): total epoch of training + steps_per_epoch(int): steps of one epoch + lr_adjust_epoch(int): lr adjust in lr_adjust_epoch, after that, the lr is lr_end + + Returns: + np.array, learning rate array + """ + lr_each_step = [] + total_steps = steps_per_epoch * total_epochs + warmup_steps = steps_per_epoch * warmup_epochs + adjust_steps = lr_adjust_epoch * steps_per_epoch + for i in range(total_steps): + if i < warmup_steps: + lr = lr_init + (lr_max - lr_init) * i / warmup_steps + elif i < adjust_steps: + lr = lr_end + \ + (lr_max - lr_end) * \ + (1. + math.cos(math.pi * (i - warmup_steps) / (adjust_steps - warmup_steps))) / 2. + else: + lr = lr_end + if lr < 0.0: + lr = 0.0 + lr_each_step.append(lr) + + current_step = global_step + lr_each_step = np.array(lr_each_step).astype(np.float32) + learning_rate = lr_each_step[current_step:] + + return learning_rate diff --git a/model_zoo/official/nlp/lstm/src/lstm.py b/model_zoo/official/nlp/lstm/src/lstm.py index 5ee90b8ad2..8a21f4ebde 100644 --- a/model_zoo/official/nlp/lstm/src/lstm.py +++ b/model_zoo/official/nlp/lstm/src/lstm.py @@ -20,6 +20,8 @@ import numpy as np from mindspore import Tensor, nn, context, Parameter, ParameterTuple from mindspore.common.initializer import initializer from mindspore.ops import operations as P +import mindspore.ops.functional as F +import mindspore.common.dtype as mstype STACK_LSTM_DEVICE = ["CPU"] @@ -44,6 +46,28 @@ def stack_lstm_default_state(batch_size, hidden_size, num_layers, bidirectional) h, c = tuple(h_list), tuple(c_list) return h, c +def stack_lstm_default_state_ascend(batch_size, hidden_size, num_layers, bidirectional): + """init default input.""" + + h_list = c_list = [] + for _ in range(num_layers): + h_fw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16)) + c_fw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16)) + h_i = [h_fw] + c_i = [c_fw] + + if bidirectional: + h_bw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16)) + c_bw = Tensor(np.zeros((1, batch_size, hidden_size)).astype(np.float16)) + h_i.append(h_bw) + c_i.append(c_bw) + + h_list.append(h_i) + c_list.append(c_i) + + h, c = tuple(h_list), tuple(c_list) + return h, c + class StackLSTM(nn.Cell): """ @@ -114,6 +138,128 @@ class StackLSTM(nn.Cell): x = self.transpose(x, (1, 0, 2)) return x, (hn, cn) +class LSTM_Ascend(nn.Cell): + """ LSTM in Ascend. """ + + def __init__(self, bidirectional=False): + super(LSTM_Ascend, self).__init__() + self.bidirectional = bidirectional + self.dynamic_rnn = P.DynamicRNN(forget_bias=0.0) + self.reverseV2 = P.ReverseV2(axis=[0]) + self.concat = P.Concat(2) + + def construct(self, x, h, c, w_f, b_f, w_b=None, b_b=None): + """construct""" + x = F.cast(x, mstype.float16) + if self.bidirectional: + y1, h1, c1, _, _, _, _, _ = self.dynamic_rnn(x, w_f, b_f, None, h[0], c[0]) + r_x = self.reverseV2(x) + y2, h2, c2, _, _, _, _, _ = self.dynamic_rnn(r_x, w_b, b_b, None, h[1], c[1]) + y2 = self.reverseV2(y2) + + output = self.concat((y1, y2)) + hn = self.concat((h1, h2)) + cn = self.concat((c1, c2)) + return output, (hn, cn) + + y1, h1, c1, _, _, _, _, _ = self.dynamic_rnn(x, w_f, b_f, None, h[0], c[0]) + return y1, (h1, c1) + +class StackLSTMAscend(nn.Cell): + """ Stack multi-layers LSTM together. """ + + def __init__(self, + input_size, + hidden_size, + num_layers=1, + has_bias=True, + batch_first=False, + dropout=0.0, + bidirectional=False): + super(StackLSTMAscend, self).__init__() + self.num_layers = num_layers + self.batch_first = batch_first + self.bidirectional = bidirectional + self.transpose = P.Transpose() + + # input_size list + input_size_list = [input_size] + for i in range(num_layers - 1): + input_size_list.append(hidden_size * 2) + + #weights, bias and layers init + weights_fw = [] + weights_bw = [] + bias_fw = [] + bias_bw = [] + + stdv = 1 / math.sqrt(hidden_size) + for i in range(num_layers): + # forward weight init + w_np_fw = np.random.uniform(-stdv, + stdv, + (input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float16) + w_fw = Parameter(initializer(Tensor(w_np_fw), w_np_fw.shape), name="w_fw_layer" + str(i)) + weights_fw.append(w_fw) + # forward bias init + if has_bias: + b_fw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float16) + b_fw = Parameter(initializer(Tensor(b_fw), b_fw.shape), name="b_fw_layer" + str(i)) + else: + b_fw = np.zeros((hidden_size * 4)).astype(np.float16) + b_fw = Parameter(initializer(Tensor(b_fw), b_fw.shape), name="b_fw_layer" + str(i)) + bias_fw.append(b_fw) + + if bidirectional: + # backward weight init + w_np_bw = np.random.uniform(-stdv, + stdv, + (input_size_list[i] + hidden_size, hidden_size * 4)).astype(np.float16) + w_bw = Parameter(initializer(Tensor(w_np_bw), w_np_bw.shape), name="w_bw_layer" + str(i)) + weights_bw.append(w_bw) + + # backward bias init + if has_bias: + b_bw = np.random.uniform(-stdv, stdv, (hidden_size * 4)).astype(np.float16) + b_bw = Parameter(initializer(Tensor(b_bw), b_bw.shape), name="b_bw_layer" + str(i)) + else: + b_bw = np.zeros((hidden_size * 4)).astype(np.float16) + b_bw = Parameter(initializer(Tensor(b_bw), b_bw.shape), name="b_bw_layer" + str(i)) + bias_bw.append(b_bw) + + # layer init + self.lstm = LSTM_Ascend(bidirectional=bidirectional) + + self.weight_fw = ParameterTuple(tuple(weights_fw)) + self.weight_bw = ParameterTuple(tuple(weights_bw)) + self.bias_fw = ParameterTuple(tuple(bias_fw)) + self.bias_bw = ParameterTuple(tuple(bias_bw)) + + def construct(self, x, hx): + """construct""" + x = F.cast(x, mstype.float16) + if self.batch_first: + x = self.transpose(x, (1, 0, 2)) + # stack lstm + h, c = hx + hn = cn = None + for i in range(self.num_layers): + if self.bidirectional: + x, (hn, cn) = self.lstm(x, + h[i], + c[i], + self.weight_fw[i], + self.bias_fw[i], + self.weight_bw[i], + self.bias_bw[i]) + else: + x, (hn, cn) = self.lstm(x, h[i], c[i], self.weight_fw[i], self.bias_fw[i]) + if self.batch_first: + x = self.transpose(x, (1, 0, 2)) + x = F.cast(x, mstype.float32) + hn = F.cast(x, mstype.float32) + cn = F.cast(x, mstype.float32) + return x, (hn, cn) class SentimentNet(nn.Cell): """Sentiment network structure.""" @@ -145,7 +291,7 @@ class SentimentNet(nn.Cell): bidirectional=bidirectional, dropout=0.0) self.h, self.c = stack_lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional) - else: + elif context.get_context("device_target") == "GPU": # standard lstm self.encoder = nn.LSTM(input_size=embed_size, hidden_size=num_hiddens, @@ -154,8 +300,16 @@ class SentimentNet(nn.Cell): bidirectional=bidirectional, dropout=0.0) self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional) + else: + self.encoder = StackLSTMAscend(input_size=embed_size, + hidden_size=num_hiddens, + num_layers=num_layers, + has_bias=True, + bidirectional=bidirectional) + self.h, self.c = stack_lstm_default_state_ascend(batch_size, num_hiddens, num_layers, bidirectional) self.concat = P.Concat(1) + self.squeeze = P.Squeeze(axis=0) if bidirectional: self.decoder = nn.Dense(num_hiddens * 4, num_classes) else: @@ -167,6 +321,6 @@ class SentimentNet(nn.Cell): embeddings = self.trans(embeddings, self.perm) output, _ = self.encoder(embeddings, (self.h, self.c)) # states[i] size(64,200) -> encoding.size(64,400) - encoding = self.concat((output[0], output[499])) + encoding = self.concat((self.squeeze(output[0:1:1]), self.squeeze(output[499:500:1]))) outputs = self.decoder(encoding) return outputs diff --git a/model_zoo/official/nlp/lstm/train.py b/model_zoo/official/nlp/lstm/train.py index a52b7cc29e..97f058c2df 100644 --- a/model_zoo/official/nlp/lstm/train.py +++ b/model_zoo/official/nlp/lstm/train.py @@ -20,9 +20,10 @@ import os import numpy as np -from src.config import lstm_cfg as cfg +from src.config import lstm_cfg, lstm_cfg_ascend from src.dataset import convert_to_mindrecord from src.dataset import lstm_create_dataset +from src.lr_schedule import get_lr from src.lstm import SentimentNet from mindspore import Tensor, nn, Model, context from mindspore.nn import Accuracy @@ -43,8 +44,8 @@ if __name__ == '__main__': help='the path to save the checkpoint file.') parser.add_argument('--pre_trained', type=str, default=None, help='the pretrained checkpoint file path.') - parser.add_argument('--device_target', type=str, default="GPU", choices=['GPU', 'CPU'], - help='the target device to run, support "GPU", "CPU". Default: "GPU".') + parser.add_argument('--device_target', type=str, default="Ascend", choices=['GPU', 'CPU', 'Ascend'], + help='the target device to run, support "GPU", "CPU". Default: "Ascend".') args = parser.parse_args() context.set_context( @@ -52,11 +53,23 @@ if __name__ == '__main__': save_graphs=False, device_target=args.device_target) + if args.device_target == 'Ascend': + cfg = lstm_cfg_ascend + else: + cfg = lstm_cfg + if args.preprocess == "true": print("============== Starting Data Pre-processing ==============") convert_to_mindrecord(cfg.embed_size, args.aclimdb_path, args.preprocess_path, args.glove_path) embedding_table = np.loadtxt(os.path.join(args.preprocess_path, "weight.txt")).astype(np.float32) + # DynamicRNN in this network on Ascend platform only support the condition that the shape of input_size + # and hiddle_size is multiples of 16, this problem will be solved later. + if args.device_target == 'Ascend': + pad_num = int(np.ceil(cfg.embed_size / 16) * 16 - cfg.embed_size) + if pad_num > 0: + embedding_table = np.pad(embedding_table, [(0, 0), (0, pad_num)], 'constant') + cfg.embed_size = int(np.ceil(cfg.embed_size / 16) * 16) network = SentimentNet(vocab_size=embedding_table.shape[0], embed_size=cfg.embed_size, num_hiddens=cfg.num_hiddens, @@ -69,14 +82,25 @@ if __name__ == '__main__': if args.pre_trained: load_param_into_net(network, load_checkpoint(args.pre_trained)) + ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1) + loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum) + if cfg.dynamic_lr: + lr = Tensor(get_lr(global_step=cfg.global_step, + lr_init=cfg.lr_init, lr_end=cfg.lr_end, lr_max=cfg.lr_max, + warmup_epochs=cfg.warmup_epochs, + total_epochs=cfg.num_epochs, + steps_per_epoch=ds_train.get_dataset_size(), + lr_adjust_epoch=cfg.lr_adjust_epoch)) + else: + lr = cfg.learning_rate + + opt = nn.Momentum(network.trainable_params(), lr, cfg.momentum) loss_cb = LossMonitor() model = Model(network, loss, opt, {'acc': Accuracy()}) print("============== Starting Training ==============") - ds_train = lstm_create_dataset(args.preprocess_path, cfg.batch_size, 1) config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, keep_checkpoint_max=cfg.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix="lstm", directory=args.ckpt_path, config=config_ck)