diff --git a/model_zoo/official/nlp/gru/README.md b/model_zoo/official/nlp/gru/README.md index ed8a1a196c..1e2f99a7b0 100644 --- a/model_zoo/official/nlp/gru/README.md +++ b/model_zoo/official/nlp/gru/README.md @@ -1,4 +1,4 @@ -![](https://www.mindspore.cn/static/img/logo.a3e472c9.png) +![](https://www.mindspore.cn/static/img/logo_black.6a5c850d.png) @@ -52,6 +52,26 @@ In this model, we use the Multi30K dataset as our train and test dataset.As trai - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html) - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html) +## Requirements + +```txt +nltk +numpy +``` + +To install nltk, you should install nltk as follow: + +```bash +pip install nltk +``` + +Then you should download extra packages as follow: + +```python +import nltk +nltk.download() +``` + # [Quick Start](#content) After dataset preparation, you can start training and evaluation as follows: diff --git a/model_zoo/official/nlp/gru/eval.py b/model_zoo/official/nlp/gru/eval.py index 0893f2e160..6fc8cb083d 100644 --- a/model_zoo/official/nlp/gru/eval.py +++ b/model_zoo/official/nlp/gru/eval.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """Transformer evaluation script.""" - +import os import argparse import mindspore.common.dtype as mstype from mindspore.common.tensor import Tensor @@ -41,8 +41,13 @@ def run_gru_eval(): context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, reserve_class_name_in_scope=False, \ device_id=args.device_id, save_graphs=False) + prefix = "multi30k_test_mindrecord_32" + mindrecord_file = os.path.join(args.dataset_path, prefix) + if not os.path.exists(mindrecord_file): + print("dataset file {} not exists, please check!".format(mindrecord_file)) + raise ValueError(mindrecord_file) dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.eval_batch_size, \ - dataset_path=args.dataset_path, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False) + dataset_path=mindrecord_file, rank_size=args.device_num, rank_id=0, do_shuffle=False, is_training=False) dataset_size = dataset.get_dataset_size() print("dataset size is {}".format(dataset_size)) network = Seq2Seq(config, is_training=False) diff --git a/model_zoo/official/nlp/gru/scripts/run_distribute_train_ascend.sh b/model_zoo/official/nlp/gru/scripts/run_distribute_train_ascend.sh index c9280d7e48..04830cc32b 100644 --- a/model_zoo/official/nlp/gru/scripts/run_distribute_train_ascend.sh +++ b/model_zoo/official/nlp/gru/scripts/run_distribute_train_ascend.sh @@ -40,9 +40,9 @@ fi DATASET_PATH=$(get_real_path $2) echo $DATASET_PATH -if [ ! -f $DATASET_PATH ] +if [ ! -d $DATASET_PATH ] then - echo "error: DATASET_PATH=$DATASET_PATH is not a file" + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" exit 1 fi diff --git a/model_zoo/official/nlp/gru/scripts/run_eval.sh b/model_zoo/official/nlp/gru/scripts/run_eval.sh index 5b0f7defed..599d934fa8 100644 --- a/model_zoo/official/nlp/gru/scripts/run_eval.sh +++ b/model_zoo/official/nlp/gru/scripts/run_eval.sh @@ -41,9 +41,9 @@ fi DATASET_PATH=$(get_real_path $2) echo $DATASET_PATH -if [ ! -f $DATASET_PATH ] +if [ ! -d $DATASET_PATH ] then - echo "error: DATASET_PATH=$DATASET_PATH is not a file" + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" exit 1 fi rm -rf ./eval diff --git a/model_zoo/official/nlp/gru/scripts/run_standalone_train.sh b/model_zoo/official/nlp/gru/scripts/run_standalone_train.sh index 0c5cca1461..b2fbc878af 100644 --- a/model_zoo/official/nlp/gru/scripts/run_standalone_train.sh +++ b/model_zoo/official/nlp/gru/scripts/run_standalone_train.sh @@ -33,9 +33,9 @@ get_real_path(){ DATASET_PATH=$(get_real_path $1) echo $DATASET_PATH -if [ ! -f $DATASET_PATH ] +if [ ! -d $DATASET_PATH ] then - echo "error: DATASET_PATH=$DATASET_PATH is not a file" + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" exit 1 fi diff --git a/model_zoo/official/nlp/gru/train.py b/model_zoo/official/nlp/gru/train.py index 998d506cc3..d4f362a073 100644 --- a/model_zoo/official/nlp/gru/train.py +++ b/model_zoo/official/nlp/gru/train.py @@ -99,8 +99,13 @@ if __name__ == '__main__': else: rank = 0 device_num = 1 + prefix = "multi30k_train_mindrecord_32_" + mindrecord_file = os.path.join(args.dataset_path, prefix+"0") + if not os.path.exists(mindrecord_file): + print("dataset file {} not exists, please check!".format(mindrecord_file)) + raise ValueError(mindrecord_file) dataset = create_gru_dataset(epoch_count=config.num_epochs, batch_size=config.batch_size, - dataset_path=args.dataset_path, rank_size=device_num, rank_id=rank) + dataset_path=mindrecord_file, rank_size=device_num, rank_id=rank) dataset_size = dataset.get_dataset_size() print("dataset size is {}".format(dataset_size)) network = Seq2Seq(config)