!8727 Clean code for gnmt_v2 network

From: @gaojing22
Reviewed-by: @yingjy,@guoqi1024
Signed-off-by: @yingjy
pull/8727/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7689062c7d

@ -12,16 +12,12 @@
- [Dataset Preparation](#dataset-preparation)
- [Configuration File](#configuration-file)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Inference Process](#inference-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Result](#result)
- [Training Performance](#training-performance)
- [Inference Performance](#inference-performance)
- [Practice](#practice)
- [Dataset Preprocessing](#dataset-preprocessing)
- [Training](#training-1)
- [Inference](#inference-1)
- [Random Situation Description](#random-situation-description)
- [Others](#others)
- [ModelZoo](#modelzoo)
@ -50,8 +46,8 @@ Note that you can run the scripts based on the dataset mentioned in original pap
- Framework
- Install [MindSpore](https://www.mindspore.cn/install/en).
- For more information, please check the resources below:
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/en/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/en/master/index.html)
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## Software
```txt
@ -62,18 +58,26 @@ subword_nmt==0.3.7
```
# [Quick Start](#contents)
The process of GNMTv2 performing the text translation task is as follows:
1. Download the wmt16 data corpus and extract the dataset. For details, see the chapter "_Dataset_" above.
2. Dataset preparation and configuration.
3. Training.
4. Inference.
After dataset preparation, you can start training and evaluation as follows:
```bash
# run training example
python train.py --config /home/workspace/gnmt_v2/config/config.json
cd ./scripts
sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
# run distributed training example
cd ./scripts
sh run_distributed_train_ascend.sh
sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
# run evaluation example
cd ./scripts
sh run_standalone_eval_ascend.sh
sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
```
# Script Description
@ -130,7 +134,7 @@ The GNMT network script and code result are as follows:
```
## Dataset Preparation
You may use this [shell script](https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/Translation/GNMT/scripts/wmt16_en_de.sh) to download and preprocess WMT English-German dataset. Assuming you get the following files:
You may use this [shell script](https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow/Translation/GNMT/scripts/wmt16_en_de.sh) to download and preprocess WMT English-German dataset. Assuming you get the following files:
- train.tok.clean.bpe.32000.en
- train.tok.clean.bpe.32000.de
- vocab.bpe.32000
@ -146,35 +150,60 @@ You may use this [shell script](https://github.com/NVIDIA/DeepLearningExamples/b
## Configuration File
The JSON file in the `config/` directory is the template configuration file.
Almost all required options and parameters can be easily assigned, including the training platform, dataset and model configuration, and optimizer parameters. By setting the corresponding options, you can also obtain optional functions such as loss scale and checkpoint.
For more information about attributes, see the `config/config.py` file.
Almost all required options and parameters can be easily assigned, including the training platform, model configuration, and optimizer parameters.
- config for GNMTv2
```python
'random_seed': 50 # global random seed
'epochs': 6 # total training epochs
'batch_size': 128 # training batch size
'dataset_sink_mode': true # whether use dataset sink mode
'seq_length': 51 # max length of source sentences
'vocab_size': 32320 # vocabulary size
'hidden_size': 125 # the output's last dimension of dynamicRNN
'initializer_range': 0.1 # initializer range
'max_decode_length': 125 # max length of decoder
'lr': 0.1 # initial learning rate
'lr_scheduler': 'WarmupMultiStepLR' # learning rate scheduler
'existed_ckpt': '' # the absolute path of an existing checkpoint file (for fine-tuning)
```
For more configuration details, please refer to the `config/config.py` script.
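For orientation, the sketch below shows how the training and evaluation entry points consume this file; it mirrors the `get_config` helper and the override assignments in `train.py`/`eval.py` from this change (the helper name `load_config` is illustrative):
```python
# Minimal sketch, assuming GNMTConfig from this repository's config package.
import mindspore.common.dtype as mstype
from config import GNMTConfig

def load_config(json_path, dataset_schema=None, pre_train_dataset=None):
    """Load the JSON template and apply the overrides train.py receives via CLI."""
    config = GNMTConfig.from_json_file(json_path)
    config.compute_type = mstype.float16
    config.dtype = mstype.float32
    # These two fields correspond to the --dataset_schema_train and
    # --pre_train_dataset arguments passed by run_standalone_train_ascend.sh.
    if dataset_schema is not None:
        config.dataset_schema = dataset_schema
    if pre_train_dataset is not None:
        config.pre_train_dataset = pre_train_dataset
    return config
```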
## Training Process
Model training requires the shell script `scripts/run_standalone_train_ascend.sh`, which sets the environment variables and runs the training script `train.py` in `gnmt_v2/`.
Start task training on a single device and run the following command in bash:
Before training, configure the following options in the `config/config.json` file:
- Select an optimizer ('momentum', 'adam', or 'lamb').
- Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to save the model file.
- Set other parameters, including dataset configuration and network configuration.
- If a pre-trained model exists, assign `existed_ckpt` to the path of the existing model during fine-tuning.
To start training on a single device, run the shell script `scripts/run_standalone_train_ascend.sh`:
```bash
cd ./scripts
sh run_standalone_train_ascend.sh
sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
```
or multiple devices
In this script, `DATASET_SCHEMA_TRAIN` and `PRE_TRAIN_DATASET` are the training dataset schema and the training dataset address, respectively.
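The schema file describes the columns of the TFRecord training data. After this change the `src_length` column is dropped; the remaining columns mirror the `SCHEMA` dictionary in `src/dataset/schema.py` (column meanings inferred from the data loader in this PR):
```python
# Column layout of the training records (names and types from src/dataset/schema.py).
SCHEMA = {
    "src": {"type": "int64", "shape": [-1]},          # source token ids
    "src_padding": {"type": "int64", "shape": [-1]},  # source padding mask
    "prev_opt": {"type": "int64", "shape": [-1]},     # shifted target ids (decoder input)
    "target": {"type": "int64", "shape": [-1]},       # target token ids (labels)
    "tgt_padding": {"type": "int64", "shape": [-1]},  # target padding mask
}
```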
Run `scripts/run_distributed_train_ascend.sh` for distributed training of the GNMTv2 model.
To train on multiple devices, run the following command in bash from the `scripts/` directory:
```bash
cd ./scripts
sh run_distributed_train_ascend.sh
sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET
```
Note: Ensure that the hccl_json file is assigned when distributed training is running.
Currently, inconsecutive device IDs are not supported in `scripts/run_distributed_train_ascend.sh`. The device ID must start from 0 in the `distribute_script/rank_table_8p.json` file.
## Evaluation Process
Set options in `config/config_test.json`. Make sure the 'existed_ckpt', 'dataset_schema' and 'test_dataset' are set to your own path.
Note: the `RANK_TABLE_ADDR` is the hccl_json file assigned when distributed training is running.
Currently, inconsecutive device IDs are not supported in `scripts/run_distributed_train_ascend.sh`. The device ID must start from 0 in the `RANK_TABLE_ADDR` file.
Run `scripts/run_standalone_eval_ascend.sh` to process the output token ids to get the BLEU scores.
## Inference Process
This section describes how to run inference with a trained model on the Ascend 910 platform.
Set options in `config/config_test.json`.
Run the shell script `scripts/run_standalone_eval_ascend.sh` to process the output token ids to get the BLEU scores.
```bash
cd ./scripts
sh run_standalone_eval_ascend.sh
sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
```
`DATASET_SCHEMA_TEST` and `TEST_DATASET` are the schema and address of the inference dataset, respectively, and `EXISTED_CKPT_PATH` is the path of the model file generated during training.
`VOCAB_ADDR` is the vocabulary address, `BPE_CODE_ADDR` is the BPE code address, and `TEST_TARGET` is the path of the reference answers.
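After the script finishes, the dumped predictions are detokenized and scored. A minimal sketch of that post-processing step, mirroring the BLEU calculation in `test_gnmt_v2.py` added by this PR (paths are the example values used in the scripts):
```python
# Minimal sketch, assuming the src/ package layout of this repository.
from src.dataset.tokenizer import Tokenizer
from src.gnmt_model.bleu_calculate import bleu_calculate

result_npy_addr = "./output.npz"                          # written by eval.py (--output)
vocab = "/home/workspace/wmt16_de_en/vocab.bpe.32000"     # VOCAB_ADDR
bpe_codes = "/home/workspace/wmt16_de_en/bpe.32000"       # BPE_CODE_ADDR
test_tgt = "/home/workspace/wmt16_de_en/newstest2014.de"  # TEST_TARGET

# Detokenize the predicted ids and score them against the reference translations.
tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
print(f"BLEU score: {scores}")
```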
# Model Description
## Performance
@ -190,8 +219,9 @@ sh run_standalone_eval_ascend.sh
| Training Parameters | epoch=6, batch_size=128 |
| Optimizer | Adam |
| Loss Function | Softmax Cross Entropy |
| BLEU Score | 24.05 |
| outputs | probability |
| Speed | 344ms/step (8pcs) |
| Total Time | 7800s (8pcs) |
| Loss | 63.35 |
| Params (M) | 613 |
| Checkpoint for inference | 1.8G (.ckpt file) |
@ -206,48 +236,10 @@ sh run_standalone_eval_ascend.sh
| MindSpore Version | 1.0.0 |
| Dataset | WMT newstest2014 |
| batch_size | 128 |
| outputs | BLEU score |
| Accuracy | BLEU= 24.05 |
## Practice
The process of GNMTv2 performing the text translation task is as follows:
1. Download the wmt16 data corpus and extract the dataset. For details, see the chapter "_Dataset_" above.
2. Dataset preprocessing.
3. Perform training.
4. Perform inference.
### Dataset Preprocessing
Run `create_dataset.py` to preprocess the downloaded corpus into the dataset and schema files:
```
python create_dataset.py --src_folder /home/work_space/wmt16_de_en --output_folder /home/work_space/dataset_menu
```
### Training
For a pre-trained model, configure the following options in the `config/config.json` file:
- Assign `pre_train_dataset` and `dataset_schema` to the training dataset path.
- Select an optimizer ('momentum/adam/lamb' is available).
- Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to save the model file.
- Set other parameters, including dataset configuration and network configuration.
- If a pre-trained model exists, assign `existed_ckpt` to the path of the existing model during fine-tuning.
Run the shell script `scripts/run_standalone_train_ascend.sh`:
```bash
cd ./scripts
sh run_standalone_train_ascend.sh
```
### Inference
For inference using a trained model on multiple hardware platforms, such as GPU, Ascend 910, and Ascend 310, see [Network Migration](https://www.mindspore.cn/tutorial/en/master/advanced_use/network_migration.html).
For inference, configure the following options in the `config/config.json` file:
- Assign `test_dataset` and the `dataset_schema` to the inference dataset path.
- Assign `existed_ckpt` and the `checkpoint_path` to the path of the model file generated during training.
- Set other parameters, including dataset configuration and network configuration.
Run the shell script `scripts/run_standalone_eval_ascend.sh`:
```bash
cd ./scripts
sh run_standalone_eval_ascend.sh
```
| Total Time | 1560s |
| outputs | probability |
| Accuracy | BLEU Score= 24.05 |
| Model for inference | 1.8G (.ckpt file) |
# Random Situation Description
There are three random situations:
@ -260,4 +252,4 @@ Some seeds have already been set in train.py to avoid the randomness of dataset
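As noted above, the entry points already seed the global generators; a minimal sketch of that call, using the `random_seed` value from the template configuration:
```python
# Minimal sketch of the seeding done in train.py and test_gnmt_v2.py.
from mindspore.common import set_seed

random_seed = 50  # 'random_seed' in config/config.json
set_seed(random_seed)
```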
This model has been validated in the Ascend environment and is not validated on the CPU and GPU.
# ModelZoo Homepage
[Link](https://gitee.com/mindspore/mindspore/tree/master/mindspore/model_zoo)
[Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo)

@ -15,6 +15,7 @@
"""Evaluation api."""
import argparse
import pickle
import os
from mindspore.common import dtype as mstype
@ -26,11 +27,17 @@ from src.dataset.tokenizer import Tokenizer
parser = argparse.ArgumentParser(description='gnmt')
parser.add_argument("--config", type=str, required=True,
help="model config json file path.")
parser.add_argument("--dataset_schema_test", type=str, required=True,
help="dataset schema for evaluation.")
parser.add_argument("--test_dataset", type=str, required=True,
help="test dataset address.")
parser.add_argument("--existed_ckpt", type=str, required=True,
help="existed checkpoint address.")
parser.add_argument("--vocab", type=str, required=True,
help="Vocabulary to use.")
parser.add_argument("--bpe_codes", type=str, required=True,
help="bpe codes to use.")
parser.add_argument("--test_tgt", type=str, required=False,
parser.add_argument("--test_tgt", type=str, required=True,
default=None,
help="data file of the test target")
parser.add_argument("--output", type=str, required=False,
@ -45,9 +52,20 @@ def get_config(config):
return config
def _check_args(config):
if not os.path.exists(config):
raise FileNotFoundError("`config` does not exist.")
if not isinstance(config, str):
raise ValueError("`config` must be type of str.")
if __name__ == '__main__':
args, _ = parser.parse_known_args()
_check_args(args.config)
_config = get_config(args.config)
_config.dataset_schema = args.dataset_schema_test
_config.test_dataset = args.test_dataset
_config.existed_ckpt = args.existed_ckpt
result = infer(_config)
with open(args.output, "wb") as f:

@ -13,14 +13,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET"
echo "for example:"
echo "sh run_distributed_train_ascend.sh \
/home/workspace/rank_table_8p.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001"
echo "It is better to use absolute path."
echo "=============================================================================================================="
RANK_TABLE_ADDR=$1
DATASET_SCHEMA_TRAIN=$2
PRE_TRAIN_DATASET=$3
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_TABLE_FILE=/home/workspace/rank_table_8p.json
export MINDSPORE_HCCL_CONFIG_PATH=/home/workspace/rank_table_8p.json
export RANK_TABLE_FILE=$RANK_TABLE_ADDR
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_ADDR
echo $RANK_TABLE_FILE
export RANK_SIZE=8
export GLOG_v=2
for((i=0;i<=7;i++));
do
@ -32,7 +49,10 @@ do
cp -r ../../config .
export RANK_ID=$i
export DEVICE_ID=$i
python ../../train.py --config /home/workspace/gnmt_v2/config/config.json > log_gnmt_network${i}.log 2>&1 &
python ../../train.py \
--config=${current_exec_path}/device${i}/config/config.json \
--dataset_schema_train=$DATASET_SCHEMA_TRAIN \
--pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network${i}.log 2>&1 &
cd ${current_exec_path} || exit
done
cd ${current_exec_path} || exit

@ -13,10 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
echo "for example:"
echo "sh run_standalone_eval_ascend.sh \
/home/workspace/dataset_menu/newstest2014.en.json \
/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001 \
/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
/home/workspace/wmt16_de_en/vocab.bpe.32000 \
/home/workspace/wmt16_de_en/bpe.32000 \
/home/workspace/wmt16_de_en/newstest2014.de"
echo "It is better to use absolute path."
echo "=============================================================================================================="
DATASET_SCHEMA_TEST=$1
TEST_DATASET=$2
EXISTED_CKPT_PATH=$3
VOCAB_ADDR=$4
BPE_CODE_ADDR=$5
TEST_TARGET=$6
current_exec_path=$(pwd)
echo ${current_exec_path}
export DEVICE_NUM=1
export DEVICE_ID=5
export RANK_ID=0
export RANK_SIZE=1
export GLOG_v=2
if [ -d "eval" ];
then
@ -29,5 +55,12 @@ cp -r ../config ./eval
cd ./eval || exit
echo "start eval for device $DEVICE_ID"
env > env.log
python eval.py --config /home/workspace/gnmt_v2/config/config_test.json --vocab /home/workspace/wmt16_de_en/vocab.bpe.32000 --bpe_codes /home/workspace/wmt16_de_en/bpe.32000 --test_tgt /home/workspace/wmt16_de_en/newstest2014.de >log_infer.log 2>&1 &
python eval.py \
--config=${current_exec_path}/eval/config/config_test.json \
--dataset_schema_test=$DATASET_SCHEMA_TEST \
--test_dataset=$TEST_DATASET \
--existed_ckpt=$EXISTED_CKPT_PATH \
--vocab=$VOCAB_ADDR \
--bpe_codes=$BPE_CODE_ADDR \
--test_tgt=$TEST_TARGET >log_infer.log 2>&1 &
cd ..

@ -13,11 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET"
echo "for example:"
echo "sh run_standalone_train_ascend.sh \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001"
echo "It is better to use absolute path."
echo "=============================================================================================================="
DATASET_SCHEMA_TRAIN=$1
PRE_TRAIN_DATASET=$2
export DEVICE_NUM=1
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1
export GLOG_v=2
current_exec_path=$(pwd)
echo ${current_exec_path}
if [ -d "train" ];
then
rm -rf ./train
@ -29,5 +44,8 @@ cp -r ../config ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --config /home/workspace/gnmt_v2/config/config.json > log_gnmt_network.log 2>&1 &
python train.py \
--config=${current_exec_path}/train/config/config.json \
--dataset_schema_train=$DATASET_SCHEMA_TRAIN \
--pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network.log 2>&1 &
cd ..

@ -103,7 +103,6 @@ class BiLingualDataLoader(DataLoader):
src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
for i in range(src_len):
src_padding[i] = 1
src_length = np.array([src_len], dtype=np.int64)
# decoder inputs
decoder_input = self.padding(tgt_tokens[:-1], self.tokenizer.padding_index, self.target_max_sen_len)
# decoder outputs
@ -119,7 +118,6 @@ class BiLingualDataLoader(DataLoader):
example = {
"src": encoder_input,
"src_padding": src_padding,
"src_length": src_length,
"prev_opt": decoder_input,
"target": decoder_output,
"tgt_padding": tgt_padding
@ -133,9 +131,9 @@ class BiLingualDataLoader(DataLoader):
print(f" | Total sen = {count}.")
if self.schema_address is not None:
provlist = [count, self.source_max_sen_len, self.source_max_sen_len, 1,
provlist = [count, self.source_max_sen_len, self.source_max_sen_len,
self.target_max_sen_len, self.target_max_sen_len, self.target_max_sen_len]
columns = ["src", "src_padding", "src_length", "prev_opt", "target", "tgt_padding"]
columns = ["src", "src_padding", "prev_opt", "target", "tgt_padding"]
with open(self.schema_address, "w", encoding="utf-8") as f:
f.write("{\n")
f.write(' "datasetType":"TF",\n')
@ -196,12 +194,10 @@ class TextDataLoader(DataLoader):
src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
for i in range(src_len):
src_padding[i] = 1
src_length = np.array([src_len], dtype=np.int64)
example = {
"src": encoder_input,
"src_padding": src_padding,
"src_length": src_length
"src_padding": src_padding
}
self._add_example(example)
count += 1
@ -211,8 +207,8 @@ class TextDataLoader(DataLoader):
print(f" | Total sen = {count}.")
if self.schema_address is not None:
provlist = [count, self.source_max_sen_len, self.source_max_sen_len, 1]
columns = ["src", "src_padding", "src_length"]
provlist = [count, self.source_max_sen_len, self.source_max_sen_len]
columns = ["src", "src_padding"]
with open(self.schema_address, "w", encoding="utf-8") as f:
f.write("{\n")
f.write(' "datasetType":"TF",\n')

@ -60,20 +60,20 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
ds = de.TFRecordDataset(
input_files, schema_file,
columns_list=[
"src", "src_padding", "src_length",
"src", "src_padding",
"prev_opt",
"target", "tgt_padding"
],
shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
shuffle=False, num_shards=rank_size, shard_id=rank_id,
shard_equal_rows=True, num_parallel_workers=8)
ori_dataset_size = ds.get_dataset_size()
print(f" | Dataset size: {ori_dataset_size}.")
if shuffle:
ds = ds.shuffle(buffer_size=ori_dataset_size // 20)
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="src_length", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="prev_opt", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="target", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="tgt_padding", operations=type_cast_op, num_parallel_workers=8)
@ -81,13 +81,11 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
ds = ds.rename(
input_columns=["src",
"src_padding",
"src_length",
"prev_opt",
"target",
"tgt_padding"],
output_columns=["source_eos_ids",
"source_eos_mask",
"source_eos_length",
"target_sos_ids",
"target_eos_ids",
"target_eos_mask"]
@ -97,26 +95,24 @@ def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
ds = de.TFRecordDataset(
input_files, schema_file,
columns_list=[
"src", "src_padding", "src_length"
"src", "src_padding"
],
shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
shuffle=False, num_shards=rank_size, shard_id=rank_id,
shard_equal_rows=True, num_parallel_workers=8)
ori_dataset_size = ds.get_dataset_size()
print(f" | Dataset size: {ori_dataset_size}.")
if shuffle:
ds = ds.shuffle(buffer_size=ori_dataset_size // 20)
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
ds = ds.map(input_columns="src_length", operations=type_cast_op, num_parallel_workers=8)
ds = ds.rename(
input_columns=["src",
"src_padding",
"src_length"],
"src_padding"],
output_columns=["source_eos_ids",
"source_eos_mask",
"source_eos_length"]
"source_eos_mask"]
)
ds = ds.batch(batch_size, drop_remainder=drop_remainder)

@ -17,7 +17,6 @@
SCHEMA = {
"src": {"type": "int64", "shape": [-1]},
"src_padding": {"type": "int64", "shape": [-1]},
"src_length": {"type": "int64", "shape": [-1]},
"prev_opt": {"type": "int64", "shape": [-1]},
"target": {"type": "int64", "shape": [-1]},
"tgt_padding": {"type": "int64", "shape": [-1]},

@ -71,7 +71,7 @@ class GNMTEncoder(nn.Cell):
self.reverse_v2 = P.ReverseV2(axis=[0])
self.dropout = nn.Dropout(keep_prob=1.0 - config.hidden_dropout_prob)
def construct(self, inputs, source_len, attention_mask=None):
def construct(self, inputs):
"""Encoder."""
inputs = self.dropout(inputs)
# bidirectional layer, fwd_encoder_outputs: [T,N,D]

@ -109,8 +109,7 @@ class GNMT(nn.Cell):
self.beam_decoder.add_flags(loop_can_unroll=True)
self.shape = P.Shape()
def construct(self, source_ids, source_mask=None, source_len=None,
target_ids=None):
def construct(self, source_ids, source_mask=None, target_ids=None):
"""
Construct network.
@ -121,8 +120,6 @@ class GNMT(nn.Cell):
source_mask (Tensor): Source sentences padding mask with shape (N, T),
where 0 indicates padding position.
target_ids (Tensor): Target sentences with shape (N, T').
target_mask (Tensor): Target sentences padding mask with shape (N, T'),
where 0 indicates padding position.
Returns:
Tuple[Tensor], network outputs.
@ -133,7 +130,7 @@ class GNMT(nn.Cell):
# T, N, D
inputs = self.transpose(src_embeddings, self.transpose_orders)
# encoder. encoder_outputs: [T, N, D]
encoder_outputs = self.gnmt_encoder(inputs, source_len=source_len)
encoder_outputs = self.gnmt_encoder(inputs)
# decoder.
if self.is_training:

@ -66,13 +66,11 @@ class GNMTInferCell(nn.Cell):
def construct(self,
source_ids,
source_mask,
source_len):
source_mask):
"""Defines the computation performed."""
predicted_ids = self.network(source_ids,
source_mask,
source_len)
source_mask)
return predicted_ids
@ -143,21 +141,18 @@ def gnmt_infer(config, dataset):
[config.batch_size, 1]), mstype.int32)
source_mask_pad = Tensor(np.tile(np.array([[1, 1] + [0] * (config.seq_length - 2)]),
[config.batch_size, 1]), mstype.int32)
source_len_pad = Tensor(np.tile(np.array([[2]]), [config.batch_size, 1]), mstype.int32)
for batch in dataset.create_dict_iterator():
source_sentences.append(batch["source_eos_ids"].asnumpy())
source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
source_mask = Tensor(batch["source_eos_mask"], mstype.int32)
source_len = Tensor(batch["source_eos_length"], mstype.int32)
active_num = shape(source_ids)[0]
if active_num < config.batch_size:
source_ids = concat((source_ids, source_ids_pad[active_num:, :]))
source_mask = concat((source_mask, source_mask_pad[active_num:, :]))
source_len = concat((source_len, source_len_pad[active_num:, :]))
start_time = time.time()
predicted_ids = model.predict(source_ids, source_mask, source_len)
predicted_ids = model.predict(source_ids, source_mask)
print(f" | BatchIndex = {batch_index}, Batch size: {config.batch_size}, active_num={active_num}, "
f"Time cost: {time.time() - start_time}.")

@ -81,21 +81,19 @@ class GNMTTraining(nn.Cell):
self.gnmt = GNMT(config, is_training, use_one_hot_embeddings)
self.projection = PredLogProbs(config)
def construct(self, source_ids, source_mask, source_len, target_ids):
def construct(self, source_ids, source_mask, target_ids):
"""
Construct network.
Args:
source_ids (Tensor): Source sentence.
source_mask (Tensor): Source padding mask.
source_len (Tensor): Effective length of source sentence.
target_ids (Tensor): Target sentence.
Returns:
Tensor, prediction_scores.
"""
decoder_outputs = self.gnmt(source_ids, source_mask, source_len, target_ids)
decoder_outputs = self.gnmt(source_ids, source_mask, target_ids)
prediction_scores = self.projection(decoder_outputs)
return prediction_scores
@ -175,11 +173,10 @@ class GNMTNetworkWithLoss(nn.Cell):
def construct(self,
source_ids,
source_mask,
source_len,
target_ids,
label_ids,
label_weights):
prediction_scores = self.gnmt(source_ids, source_mask, source_len, target_ids)
prediction_scores = self.gnmt(source_ids, source_mask, target_ids)
total_loss = self.loss(prediction_scores, label_ids, label_weights)
return self.cast(total_loss, mstype.float32)
@ -265,7 +262,6 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
def construct(self,
source_eos_ids,
source_eos_mask,
source_eos_length,
target_sos_ids,
target_eos_ids,
target_eos_mask,
@ -276,7 +272,6 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
Args:
source_eos_ids (Tensor): Source sentence.
source_eos_mask (Tensor): Source padding mask.
source_eos_length (Tensor): Effective length of source sentence.
target_sos_ids (Tensor): Target sentence.
target_eos_ids (Tensor): Prediction sentence.
target_eos_mask (Tensor): Prediction padding mask.
@ -287,7 +282,6 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
"""
source_ids = source_eos_ids
source_mask = source_eos_mask
source_len = source_eos_length
target_ids = target_sos_ids
label_ids = target_eos_ids
label_weights = target_eos_mask
@ -295,7 +289,6 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
weights = self.weights
loss = self.network(source_ids,
source_mask,
source_len,
target_ids,
label_ids,
label_weights)
@ -309,7 +302,6 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell):
scaling_sens = sens
grads = self.grad(self.network, weights)(source_ids,
source_mask,
source_len,
target_ids,
label_ids,
label_weights,

@ -40,6 +40,8 @@ from src.utils.optimizer import Adam
parser = argparse.ArgumentParser(description='GNMT train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.")
parser.add_argument("--dataset_schema_train", type=str, required=True, help="dataset schema for train.")
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
device_id = os.getenv('DEVICE_ID', None)
if device_id is None:
@ -351,9 +353,9 @@ if __name__ == '__main__':
args, _ = parser.parse_known_args()
_check_args(args.config)
_config = get_config(args.config)
_config.dataset_schema = args.dataset_schema_train
_config.pre_train_dataset = args.pre_train_dataset
set_seed(_config.random_seed)
if _rank_size is not None and int(_rank_size) > 1:
train_parallel(_config)
else:

@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Train and eval api."""
import os
import argparse
import pickle
import datetime
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from config import GNMTConfig
from train import train_parallel
from src.gnmt_model import infer
from src.gnmt_model.bleu_calculate import bleu_calculate
from src.dataset.tokenizer import Tokenizer
parser = argparse.ArgumentParser(description='GNMT train and eval.')
# train
parser.add_argument("--config_train", type=str, required=True,
help="model config json file path.")
parser.add_argument("--dataset_schema_train", type=str, required=True,
help="dataset schema for train.")
parser.add_argument("--pre_train_dataset", type=str, required=True,
help="pre-train dataset address.")
# eval
parser.add_argument("--config_test", type=str, required=True,
help="model config json file path.")
parser.add_argument("--dataset_schema_test", type=str, required=True,
help="dataset schema for evaluation.")
parser.add_argument("--test_dataset", type=str, required=True,
help="test dataset address.")
parser.add_argument("--existed_ckpt", type=str, required=True,
help="existed checkpoint address.")
parser.add_argument("--vocab", type=str, required=True,
help="Vocabulary to use.")
parser.add_argument("--bpe_codes", type=str, required=True,
help="bpe codes to use.")
parser.add_argument("--test_tgt", type=str, required=True,
default=None,
help="data file of the test target")
parser.add_argument("--output", type=str, required=False,
default="./output.npz",
help="result file path.")
def get_config(config):
config = GNMTConfig.from_json_file(config)
config.compute_type = mstype.float16
config.dtype = mstype.float32
return config
def _check_args(config):
if not os.path.exists(config):
raise FileNotFoundError("`config` does not exist.")
if not isinstance(config, str):
raise ValueError("`config` must be type of str.")
if __name__ == '__main__':
start_time = datetime.datetime.now()
_rank_size = os.getenv('RANK_SIZE')
args, _ = parser.parse_known_args()
# train
_check_args(args.config_train)
_config_train = get_config(args.config_train)
_config_train.dataset_schema = args.dataset_schema_train
_config_train.pre_train_dataset = args.pre_train_dataset
set_seed(_config_train.random_seed)
assert _rank_size is not None and int(_rank_size) > 1
if _rank_size is not None and int(_rank_size) > 1:
train_parallel(_config_train)
# eval
_check_args(args.config_test)
_config_test = get_config(args.config_test)
_config_test.dataset_schema = args.dataset_schema_test
_config_test.test_dataset = args.test_dataset
_config_test.existed_ckpt = args.existed_ckpt
result = infer(_config_test)
with open(args.output, "wb") as f:
pickle.dump(result, f, 1)
result_npy_addr = args.output
vocab = args.vocab
bpe_codes = args.bpe_codes
test_tgt = args.test_tgt
tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
print(f"BLEU scores is :{scores}")
end_time = datetime.datetime.now()
cost_time = (end_time - start_time).seconds
print(f"Cost time is {cost_time}s.")
assert scores >= 23.8
assert cost_time < 10800.0
print("----done!----")

@ -0,0 +1,86 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "sh run_distributed_train_ascend.sh \
GNMT_ADDR RANK_TABLE_ADDR \
DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET \
DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
echo "for example:"
echo "sh run_distributed_train_ascend.sh \
/home/workspace/gnmt_v2 \
/home/workspace/rank_table_8p.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001 \
/home/workspace/dataset_menu/newstest2014.en.json \
/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001 \
/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
/home/workspace/wmt16_de_en/vocab.bpe.32000 \
/home/workspace/wmt16_de_en/bpe.32000 \
/home/workspace/wmt16_de_en/newstest2014.de"
echo "It is better to use absolute path."
echo "=============================================================================================================="
GNMT_ADDR=$1
RANK_TABLE_ADDR=$2
# train dataset addr
DATASET_SCHEMA_TRAIN=$3
PRE_TRAIN_DATASET=$4
# eval dataset addr
DATASET_SCHEMA_TEST=$5
TEST_DATASET=$6
EXISTED_CKPT_PATH=$7
VOCAB_ADDR=$8
BPE_CODE_ADDR=$9
TEST_TARGET=${10}
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_TABLE_FILE=$RANK_TABLE_ADDR
export MINDSPORE_HCCL_CONFIG_PATH=$RANK_TABLE_ADDR
echo $RANK_TABLE_FILE
export RANK_SIZE=8
export GLOG_v=2
for((i=0;i<=7;i++));
do
rm -rf ${current_exec_path}/device$i
mkdir ${current_exec_path}/device$i
cd ${current_exec_path}/device$i || exit
cp ${current_exec_path}/*.py .
cp ${GNMT_ADDR}/*.py .
cp -r ${GNMT_ADDR}/src .
cp -r ${GNMT_ADDR}/config .
export RANK_ID=$i
export DEVICE_ID=$i
python test_gnmt_v2.py \
--config_train=${GNMT_ADDR}/config/config.json \
--dataset_schema_train=$DATASET_SCHEMA_TRAIN \
--pre_train_dataset=$PRE_TRAIN_DATASET \
--config_test=${GNMT_ADDR}/config/config_test.json \
--dataset_schema_test=$DATASET_SCHEMA_TEST \
--test_dataset=$TEST_DATASET \
--existed_ckpt=$EXISTED_CKPT_PATH \
--vocab=$VOCAB_ADDR \
--bpe_codes=$BPE_CODE_ADDR \
--test_tgt=$TEST_TARGET > log_gnmt_network${i}.log 2>&1 &
cd ${current_exec_path} || exit
done
cd ${current_exec_path} || exit