!11632 fix gru net usablity bug

From: @qujianwei
Reviewed-by: @linqingke,@liangchenghui
Signed-off-by: @liangchenghui
pull/11632/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 74c2b957f7

@ -1,4 +1,4 @@
![](https://www.mindspore.cn/static/img/logo.a3e472c9.png) ![](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png)
<!-- TOC --> <!-- TOC -->
@ -52,6 +52,26 @@ In this model, we use the Multi30K dataset as our train and test dataset.As trai
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## Requirements
```txt
nltk
numpy
```
To install nltk, you should install nltk as follow:
```bash
pip install nltk
```
Then you should download extra packages as follow:
```python
import nltk
nltk.download()
```
# [Quick Start](#content) # [Quick Start](#content)
After dataset preparation, you can start training and evaluation as follows: After dataset preparation, you can start training and evaluation as follows:

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Transformer evaluation script.""" """Transformer evaluation script."""
import os
import argparse import argparse
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
@ -41,8 +41,13 @@ def run_gru_eval():
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \
device_id=args.device_id, save_graphs=False) device_id=args.device_id, save_graphs=False)
prefix = "multi30k_test_mindrecord_32"
mindrecord_file = os.path.join(args.dataset_path, prefix)
if not os.path.exists(mindrecord_file):
print("dataset file {} not exists, please check!".format(mindrecord_file))
raise ValueError(mindrecord_file)
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \ dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \
dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False) dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False)
dataset_size = dataset.get_dataset_size() dataset_size = dataset.get_dataset_size()
print("dataset size is {}".format(dataset_size)) print("dataset size is {}".format(dataset_size))
network = Seq2Seq(config, is_training=False) network = Seq2Seq(config, is_training=False)

@ -40,9 +40,9 @@ fi
DATASET_PATH=$(get_real_path $2) DATASET_PATH=$(get_real_path $2)
echo $DATASET_PATH echo $DATASET_PATH
if [ ! -f $DATASET_PATH ] if [ ! -d $DATASET_PATH ]
then then
echo "error: DATASET_PATH=$DATASET_PATH is not a file" echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1 exit 1
fi fi

@ -41,9 +41,9 @@ fi
DATASET_PATH=$(get_real_path $2) DATASET_PATH=$(get_real_path $2)
echo $DATASET_PATH echo $DATASET_PATH
if [ ! -f $DATASET_PATH ] if [ ! -d $DATASET_PATH ]
then then
echo "error: DATASET_PATH=$DATASET_PATH is not a file" echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1 exit 1
fi fi
rm -rf ./eval rm -rf ./eval

@ -33,9 +33,9 @@ get_real_path(){
DATASET_PATH=$(get_real_path $1) DATASET_PATH=$(get_real_path $1)
echo $DATASET_PATH echo $DATASET_PATH
if [ ! -f $DATASET_PATH ] if [ ! -d $DATASET_PATH ]
then then
echo "error: DATASET_PATH=$DATASET_PATH is not a file" echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
exit 1 exit 1
fi fi

@ -99,8 +99,13 @@ if __name__ == '__main__':
else: else:
rank = 0 rank = 0
device_num = 1 device_num = 1
prefix = "multi30k_train_mindrecord_32_"
mindrecord_file = os.path.join(args.dataset_path, prefix+"0")
if not os.path.exists(mindrecord_file):
print("dataset file {} not exists, please check!".format(mindrecord_file))
raise ValueError(mindrecord_file)
dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size, dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size,
dataset_path=args.dataset_path, rank_size=device_num, rank_id=rank) dataset_path=mindrecord_file, rank_size=device_num, rank_id=rank)
dataset_size = dataset.get_dataset_size() dataset_size = dataset.get_dataset_size()
print("dataset size is {}".format(dataset_size)) print("dataset size is {}".format(dataset_size))
network = Seq2Seq(config) network = Seq2Seq(config)

Loading…
Cancel
Save