

GRU (Gated Recurrent Unit) is a kind of recurrent neural network, similar to LSTM (Long Short-Term Memory). It was proposed by Kyunghyun Cho, Bart van Merrienboer, et al. in the 2014 article "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation". The paper proposes a novel neural network model called RNN Encoder-Decoder that consists of two recurrent neural networks (RNNs). To improve translation quality, this implementation also draws on "Sequence to Sequence Learning with Neural Networks" and "Neural Machine Translation by Jointly Learning to Align and Translate".


1. Paper: "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation", 2014, Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio

2. Paper: "Sequence to Sequence Learning with Neural Networks", 2014, Ilya Sutskever, Oriol Vinyals, Quoc V. Le

3. Paper: "Neural Machine Translation by Jointly Learning to Align and Translate", 2014, Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio
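
For orientation, the gating mechanism described in the first paper can be sketched in a few lines of NumPy. This is only an illustration of the published equations, not the cell implemented in this repository: the helper names (`gru_cell`, `init_params`), the toy sizes, and the omission of bias terms are all assumptions made for brevity.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_cell(x, h_prev, params):
    """One GRU step following the gating equations of Cho et al. (2014).

    Bias terms are omitted to keep the sketch short.
    """
    W_r, U_r, W_z, U_z, W_h, U_h = params
    r = sigmoid(W_r @ x + U_r @ h_prev)              # reset gate
    z = sigmoid(W_z @ x + U_z @ h_prev)              # update gate
    h_tilde = np.tanh(W_h @ x + U_h @ (r * h_prev))  # candidate state
    # The update gate interpolates between the old state and the candidate.
    return z * h_prev + (1.0 - z) * h_tilde

def init_params(input_size, hidden_size, seed=0):
    """Hypothetical helper: small uniform weights for the six matrices."""
    rng = np.random.default_rng(seed)
    shapes = [(hidden_size, input_size), (hidden_size, hidden_size)] * 3
    return [rng.uniform(-0.1, 0.1, size=s) for s in shapes]

# Run the cell over a toy sequence of 5 input vectors of size 4.
params = init_params(input_size=4, hidden_size=3)
h = np.zeros(3)
for x in np.ones((5, 4)):
    h = gru_cell(x, h, params)
```

Deep learning frameworks often use slightly different but equivalent gate conventions, so the exact form in the source code may differ from this sketch.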

Model Structure

The GRU model mainly consists of an Encoder and a Decoder. The Encoder is built from a bidirectional GRU cell. The Decoder mainly contains an attention module and a GRU cell. The input of the network is a sequence of words (a text or sentence), and the output is the probability of each word in the vocab; the word with the maximum probability is chosen as the prediction.
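
The two decoder-side ideas above, attention over encoder outputs and choosing the most probable word, can be sketched as follows. This assumes additive (Bahdanau-style) attention, which is what the third referenced paper describes; the function names, weight shapes, and toy vocab are hypothetical and are not taken from the repository code.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - np.max(x))   # subtract the max for numerical stability
    return e / e.sum()

def additive_attention(dec_state, enc_outputs, W_s, W_h, v):
    """Additive attention: score each encoder output against the decoder state.

    dec_state:   (dec_dim,)
    enc_outputs: (src_len, enc_dim)
    W_s: (attn_dim, dec_dim), W_h: (attn_dim, enc_dim), v: (attn_dim,)
    """
    scores = np.tanh(enc_outputs @ W_h.T + W_s @ dec_state) @ v  # (src_len,)
    weights = softmax(scores)            # attention distribution over source
    context = weights @ enc_outputs      # weighted sum, shape (enc_dim,)
    return context, weights

def greedy_pick(logits, vocab):
    """Return the vocab entry with the maximum probability."""
    probs = softmax(logits)
    return vocab[int(np.argmax(probs))]

rng = np.random.default_rng(0)
enc_outputs = rng.normal(size=(6, 8))    # 6 source positions
dec_state = rng.normal(size=5)
W_s = rng.normal(size=(7, 5))
W_h = rng.normal(size=(7, 8))
v = rng.normal(size=7)
context, weights = additive_attention(dec_state, enc_outputs, W_s, W_h, v)

word = greedy_pick(np.array([0.1, 2.0, -1.0]), ["<unk>", "a", "dog"])
```

In the toy call at the end, the second logit is the largest, so `word` is `"a"`.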


This model uses the Multi30K dataset for training and testing. The training set provides 29,000 sentence pairs, each containing a German sentence and its English translation. The test set provides 1,000 German sentences with their English translations. A preprocess script is also provided to tokenize the dataset and create the vocab file.

Environment Requirements



To install nltk, run:

pip install nltk

Then download the extra packages as follows:

import nltk

Quick Start

After dataset preparation, you can start training and evaluation as follows:

# run training example
cd ./scripts

# run distributed training example

# run evaluation example

Script Description

The GRU network scripts and code are organized as follows:

├── gru
  ├──                              // Introduction of GRU model.
  ├── src
  │   ├──                              // GRU cell architecture.
  │   ├──                           // Configuration instance definition.
  │   ├──                      // Dataset preparation.
  │   ├──                          // Dataset loader to feed into model.
  │   ├──                    // GRU eval model architecture.
  │   ├──                    // GRU train model architecture.
  │   ├──                             // Loss architecture.
  │   ├──                      // Learning rate scheduler.
  │   ├──                     // Parse output file.
  │   ├──                       // Dataset preprocess.
  │   ├──                          // Seq2seq architecture.
  │   ├──                     // tokenization for the dataset.
  │   ├──                      // Initialize weights in the net.
  ├── scripts
  │   ├──                   // shell script for creating the dataset.
  │   ├──                     // shell script for parsing the eval output file to calculate BLEU.
  │   ├──                       // shell script for preprocessing the dataset.
  │   ├──            // shell script for distributed training on Ascend.
  │   ├──                         // shell script for standalone eval on Ascend.
  │   ├──             // shell script for standalone training on Ascend.
  ├──                                // Infer API entry.
  ├── requirements.txt                       // Requirements of third party package.
  ├──                               // Train API entry.

Dataset Preparation

First, download the dataset from the official WMT16 website. The Multi30k download contains six dataset files; put them all in the same directory.

Then, use the preprocess script in scripts/ to tokenize the dataset files and generate the vocab files.


After preprocessing, we get dataset files with the ".tok" suffix and two vocab files, one of which is vocab.en. A script in scripts/ is then provided to create the dataset files in mindrecord format.
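
To make the tokenize-then-vocab step more concrete, here is a minimal sketch of building a vocab from tokenized sentences and turning a sentence into fixed-length ids. It is an assumption-laden illustration: the special tokens, their order, the frequency-sorted layout, and the helper names `build_vocab` and `encode` are hypothetical and may not match the preprocess script. Only the length of 32 comes from the `max_length` setting in the configuration.

```python
from collections import Counter

# Assumed special tokens; the real preprocess script may use different ones.
SPECIAL_TOKENS = ["<pad>", "<unk>", "<sos>", "<eos>"]

def build_vocab(tokenized_sentences, min_freq=1):
    """Map each token to an id, most frequent first, after the special tokens."""
    counts = Counter(tok for sent in tokenized_sentences for tok in sent)
    ordered = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    words = [w for w, c in ordered if c >= min_freq]
    return {tok: idx for idx, tok in enumerate(SPECIAL_TOKENS + words)}

def encode(tokens, vocab, max_length=32):
    """Convert tokens to ids, add <sos>/<eos>, then pad to max_length."""
    tokens = tokens[:max_length - 2]          # leave room for <sos> and <eos>
    ids = [vocab["<sos>"]]
    ids += [vocab.get(t, vocab["<unk>"]) for t in tokens]
    ids.append(vocab["<eos>"])
    return ids + [vocab["<pad>"]] * (max_length - len(ids))

sentences = [["a", "dog", "runs"], ["a", "cat", "sleeps"]]
vocab = build_vocab(sentences)
ids = encode(["a", "bird"], vocab, max_length=6)   # "bird" is out of vocab
```

Here `ids` is `[2, 4, 1, 3, 0, 0]`: `<sos>`, the id of "a", `<unk>` for the unseen word, `<eos>`, then padding.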


Finally, we will get multi30k_train_mindrecord_0 ~ multi30k_train_mindrecord_8 as our train dataset, and multi30k_test_mindrecord as our test dataset.

Configuration File

Parameters for both training and evaluation can be set in the configuration file. All datasets use the same parameter names; parameter values can be changed as needed.

  • Network Parameters

      "batch_size": 16,                  # batch size of input dataset.
      "src_vocab_size": 8154,            # source dataset vocabulary size.
      "trg_vocab_size": 6113,            # target dataset vocabulary size.
      "encoder_embedding_size": 256,     # encoder embedding size.
      "decoder_embedding_size": 256,     # decoder embedding size.
      "hidden_size": 512,                # hidden size of gru.
      "max_length": 32,                  # max sentence length.
      "num_epochs": 30,                  # total epoch.
      "save_checkpoint": True,           # whether save checkpoint file.
      "ckpt_epoch": 1,                   # frequency of saving the checkpoint file.
      "target_file": "target.txt",       # the target file.
      "output_file": "output.txt",       # the output file.
      "keep_checkpoint_max": 30,         # the maximum number of checkpoint file.
      "base_lr": 0.001,                  # init learning rate.
      "warmup_step": 300,                # warmup step.
      "momentum": 0.9,                   # momentum in optimizer.
      "init_loss_scale_value": 1024,     # init scale sense.
      "scale_factor": 2,                 # scale factor for dynamic loss scale.
      "scale_window": 2000,              # scale window for dynamic loss scale.
      "warmup_ratio": 1/3.0,             # warmup ratio.
      "teacher_force_ratio": 0.5         # teacher force ratio.
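
As a rough guide to how the learning-rate and loss-scale parameters interact, the sketch below shows one plausible interpretation. It is not the repository's scheduler: the linear warmup shape, holding `base_lr` after warmup, and the `DynamicLossScale` class are assumptions for illustration. Only the parameter names and values come from the list above.

```python
CONFIG = {
    "base_lr": 0.001,
    "warmup_step": 300,
    "warmup_ratio": 1 / 3.0,
    "init_loss_scale_value": 1024,
    "scale_factor": 2,
    "scale_window": 2000,
}

def warmup_lr(step, cfg=CONFIG):
    """Assumed schedule: linear warmup from base_lr * warmup_ratio to base_lr.

    What happens after warmup is not described here, so this sketch
    simply holds the rate at base_lr.
    """
    if step >= cfg["warmup_step"]:
        return cfg["base_lr"]
    start = cfg["base_lr"] * cfg["warmup_ratio"]
    return start + (cfg["base_lr"] - start) * step / cfg["warmup_step"]

class DynamicLossScale:
    """Typical dynamic loss scaling: shrink on overflow, grow after a clean window."""

    def __init__(self, cfg=CONFIG):
        self.scale = float(cfg["init_loss_scale_value"])
        self.factor = cfg["scale_factor"]
        self.window = cfg["scale_window"]
        self.good_steps = 0

    def update(self, overflow):
        if overflow:
            # Gradients overflowed: reduce the scale and restart the window.
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.window:
                # A full window without overflow: try a larger scale.
                self.scale *= self.factor
                self.good_steps = 0
        return self.scale

scaler = DynamicLossScale()
after_overflow = scaler.update(overflow=True)
for _ in range(CONFIG["scale_window"]):
    after_window = scaler.update(overflow=False)
```

With these values, one overflow halves the scale from 1024 to 512, and 2000 consecutive clean steps bring it back to 1024.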

Training Process

  • To train on a single device, run the shell script:

    cd ./scripts
  • To run distributed training of GRU on multiple devices, execute the following command in bash from scripts/:

    cd ./scripts
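
During training, the `teacher_force_ratio` parameter from the configuration controls how often the decoder is fed the ground-truth token instead of its own previous prediction. The loop below is a minimal sketch of that idea under stated assumptions: the per-step coin flip, the `step_fn` callback, and the function name are hypothetical, and the repository may apply the ratio differently (for example per batch).

```python
import random

def decode_with_teacher_forcing(step_fn, targets, teacher_force_ratio=0.5, seed=0):
    """Run a decoder step function over a target sequence.

    step_fn: hypothetical callable mapping the current input token to a
             predicted next token (stands in for attention + GRU + argmax).
    targets: ground-truth token ids, starting with the start-of-sentence id.
    """
    rng = random.Random(seed)
    token = targets[0]
    outputs = []
    for t in range(1, len(targets)):
        pred = step_fn(token)
        outputs.append(pred)
        # With probability teacher_force_ratio feed the ground truth,
        # otherwise feed the model's own prediction back in.
        use_teacher = rng.random() < teacher_force_ratio
        token = targets[t] if use_teacher else pred
    return outputs

# Toy "model" that adds 100 to its input so the data flow is easy to follow.
toy_step = lambda tok: tok + 100
always_teacher = decode_with_teacher_forcing(toy_step, [1, 2, 3, 4], 1.0)
never_teacher = decode_with_teacher_forcing(toy_step, [1, 2, 3, 4], 0.0)
```

With a ratio of 1.0 every step sees the ground truth (`[101, 102, 103]`); with 0.0 the predictions are fed back in (`[101, 201, 301]`).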

Inference Process

  • To run evaluation of GRU, use the command below.

    cd ./scripts
  • After evaluation, we get eval/target.txt and eval/output.txt. Then we can use the parse script in scripts/ to get the translation.

    cp eval/*.txt ./
    sh target.txt output.txt /path/vocab.en
  • After parsing the output, we get target.txt.forbleu and output.txt.forbleu. To calculate the BLEU score, you may use this perl script and run the following command.

    perl multi-bleu.perl target.txt.forbleu < output.txt.forbleu
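
For readers who want to see what the BLEU computation involves, here is a compact pure-Python sketch of corpus-level BLEU with a single reference per sentence. It is meant as an illustration only and is not a replacement for multi-bleu.perl: the function names are hypothetical, and it assumes each line of the `.forbleu` files is a whitespace-tokenized sentence.

```python
import math
from collections import Counter

def ngrams(tokens, n):
    """Count all n-grams of length n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def corpus_bleu(references, hypotheses, max_n=4):
    """Corpus BLEU (0-100) with one reference per hypothesis."""
    matches = [0] * max_n
    totals = [0] * max_n
    ref_len = hyp_len = 0
    for ref, hyp in zip(references, hypotheses):
        ref_len += len(ref)
        hyp_len += len(hyp)
        for n in range(1, max_n + 1):
            hyp_counts = ngrams(hyp, n)
            ref_counts = ngrams(ref, n)
            # Clipped counts: an n-gram is credited at most as often
            # as it appears in the reference.
            matches[n - 1] += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            totals[n - 1] += sum(hyp_counts.values())
    if hyp_len == 0 or min(matches) == 0:
        return 0.0
    log_prec = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Brevity penalty for hypotheses shorter than the references.
    bp = 1.0 if hyp_len > ref_len else math.exp(1.0 - ref_len / hyp_len)
    return 100.0 * bp * math.exp(log_prec)

refs = ["a man is riding a bike".split(), "two dogs play in the snow".split()]
perfect = corpus_bleu(refs, refs)
disjoint = corpus_bleu(refs, ["x y z w".split(), "p q r s".split()])
```

To score real output, one could read `target.txt.forbleu` and `output.txt.forbleu` line by line, split each line on whitespace, and pass the two lists to `corpus_bleu`. Small differences from the perl script's result are possible.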

Note: DATASET_PATH is the path to the mindrecord file, e.g. train: /dataset_path/multi30k_train_mindrecord_0, eval: /dataset_path/multi30k_test_mindrecord

Model Description


Training Performance

| Parameters | Ascend |
| --- | --- |
| Resource | Ascend 910 |
| Uploaded Date | 01/18/2021 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Multi30k Dataset |
| Training Parameters | epoch=30, batch_size=16 |
| Optimizer | Adam |
| Loss Function | NLLLoss |
| Outputs | probability |
| Speed | 50 ms/step (1pcs) |
| Epoch Time | 13.4 s (1pcs) |
| Loss | 2.5984 |
| Params (M) | 21 |
| Checkpoint for inference | 272M (.ckpt file) |
| Scripts | gru |

Inference Performance

| Parameters | Ascend |
| --- | --- |
| Resource | Ascend 910 |
| Uploaded Date | 01/18/2021 (month/day/year) |
| MindSpore Version | 1.1.0 |
| Dataset | Multi30K |
| batch_size | 1 |
| Outputs | label index |
| Accuracy | BLEU: 30.30 |
| Model for inference | 272M (.ckpt file) |

Random Situation Description

There is only one source of randomness:

  • Initialization of some model weights.

Some seeds have already been set to avoid the randomness of weight initialization.
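
The effect of fixing a seed can be illustrated with the standard library alone. This is a generic sketch, not the repository's weight initializer: `init_weights`, the uniform range, and the seed values are hypothetical, and the actual code would rely on the framework's own seeding.

```python
import random

def init_weights(shape, seed, scale=0.1):
    """Hypothetical initializer: uniform values drawn from a seeded generator."""
    rng = random.Random(seed)
    rows, cols = shape
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]

a = init_weights((2, 3), seed=1)
b = init_weights((2, 3), seed=1)   # same seed: identical weights
c = init_weights((2, 3), seed=2)   # different seed: different weights
```

Because `a` and `b` share a seed they are identical, which is what makes repeated training runs start from the same weights.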


This model has been validated in the Ascend environment and has not been validated on CPU or GPU.

ModelZoo HomePage

Please check the official homepage.