delete schema_json

pull/10412/head
gaojing 4 years ago
parent f1b6b7ad09
commit 1f3d9ddff8

@ -58,8 +58,8 @@ Note that you can run the scripts based on the dataset mentioned in original pap
```txt ```txt
numpy numpy
sacrebleu==1.2.10 sacrebleu==1.4.14
sacremoses==0.0.19 sacremoses==0.0.35
subword_nmt==0.3.7 subword_nmt==0.3.7
``` ```
@ -77,15 +77,15 @@ After dataset preparation, you can start training and evaluation as follows:
```bash ```bash
# run training example # run training example
cd ./scripts cd ./scripts
sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET
# run distributed training example # run distributed training example
cd ./scripts cd ./scripts
sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET
# run evaluation example # run evaluation example
cd ./scripts cd ./scripts
sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \ sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
``` ```
@ -187,7 +187,7 @@ For more configuration details, please refer the script `config/config.py` file.
## Training Process ## Training Process
For a pre-trained model, configure the following options in the `scripts/run_standalone_train_ascend.json` file: For a pre-trained model, configure the following options in the `config/config.json` file:
- Select an optimizer ('momentum/adam/lamb' is available). - Select an optimizer ('momentum/adam/lamb' is available).
- Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to save the model file. - Specify `ckpt_prefix` and `ckpt_path` in `checkpoint_path` to save the model file.
@ -198,17 +198,17 @@ Start task training on a single device and run the shell script `scripts/run_sta
```bash ```bash
cd ./scripts cd ./scripts
sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET
``` ```
In this script, the `DATASET_SCHEMA_TRAIN` and `PRE_TRAIN_DATASET` are the dataset schema and dataset address. In this script, the `PRE_TRAIN_DATASET` is the dataset address.
Run `scripts/run_distributed_train_ascend.sh` for distributed training of GNMTv2 model. Run `scripts/run_distributed_train_ascend.sh` for distributed training of GNMTv2 model.
Task training on multiple devices and run the following command in bash to be executed in `scripts/`.: Task training on multiple devices and run the following command in bash to be executed in `scripts/`.:
```bash ```bash
cd ./scripts cd ./scripts
sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET
``` ```
Note: the `RANK_TABLE_ADDR` is the hccl_json file assigned when distributed training is running. Note: the `RANK_TABLE_ADDR` is the hccl_json file assigned when distributed training is running.
@ -224,11 +224,11 @@ Run the shell script `scripts/run_standalone_eval_ascend.sh` to process the outp
```bash ```bash
cd ./scripts cd ./scripts
sh run_standalone_eval_ascend.sh sh run_standalone_eval_ascend.sh
sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \ sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET
``` ```
The `DATASET_SCHEMA_TEST` and the `TEST_DATASET` are the schema and address of inference dataset respectively, and `EXISTED_CKPT_PATH` is the path of the model file generated during training process. The `TEST_DATASET` is the address of inference dataset, and `EXISTED_CKPT_PATH` is the path of the model file generated during training process.
The `VOCAB_ADDR` is the vocabulary address, `BPE_CODE_ADDR` is the bpe code address and the `TEST_TARGET` are the path of answers. The `VOCAB_ADDR` is the vocabulary address, `BPE_CODE_ADDR` is the bpe code address and the `TEST_TARGET` are the path of answers.
# [Model Description](#contents) # [Model Description](#contents)

@ -3,7 +3,6 @@
"random_seed": 50, "random_seed": 50,
"epochs": 6, "epochs": 6,
"batch_size": 128, "batch_size": 128,
"dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
"pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord", "pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord",
"fine_tune_dataset": null, "fine_tune_dataset": null,
"valid_dataset": null, "valid_dataset": null,

@ -67,7 +67,6 @@ class GNMTConfig:
random_seed (int): Random seed, it can be changed. random_seed (int): Random seed, it can be changed.
epochs (int): Epoch number. epochs (int): Epoch number.
batch_size (int): Batch size of input dataset. batch_size (int): Batch size of input dataset.
dataset_schema (str): Path of dataset schema file.
pre_train_dataset (str): Path of pre-training dataset file or folder. pre_train_dataset (str): Path of pre-training dataset file or folder.
fine_tune_dataset (str): Path of fine-tune dataset file or folder. fine_tune_dataset (str): Path of fine-tune dataset file or folder.
test_dataset (str): Path of test dataset file or folder. test_dataset (str): Path of test dataset file or folder.
@ -126,7 +125,6 @@ class GNMTConfig:
def __init__(self, def __init__(self,
random_seed=50, random_seed=50,
epochs=6, batch_size=128, epochs=6, batch_size=128,
dataset_schema: str = None,
pre_train_dataset: str = None, pre_train_dataset: str = None,
fine_tune_dataset: str = None, fine_tune_dataset: str = None,
test_dataset: str = None, test_dataset: str = None,
@ -157,7 +155,6 @@ class GNMTConfig:
self.save_graphs = save_graphs self.save_graphs = save_graphs
self.random_seed = random_seed self.random_seed = random_seed
self.dataset_schema = dataset_schema
self.pre_train_dataset = get_source_list(pre_train_dataset) # type: List[str] self.pre_train_dataset = get_source_list(pre_train_dataset) # type: List[str]
self.fine_tune_dataset = get_source_list(fine_tune_dataset) # type: List[str] self.fine_tune_dataset = get_source_list(fine_tune_dataset) # type: List[str]
self.valid_dataset = get_source_list(valid_dataset) # type: List[str] self.valid_dataset = get_source_list(valid_dataset) # type: List[str]

@ -3,7 +3,6 @@
"random_seed": 50, "random_seed": 50,
"epochs": 6, "epochs": 6,
"batch_size": 128, "batch_size": 128,
"dataset_schema": "/home/workspace/dataset_menu/newstest2014.en.json",
"pre_train_dataset": null, "pre_train_dataset": null,
"fine_tune_dataset": null, "fine_tune_dataset": null,
"test_dataset": "/home/workspace/dataset_menu/newstest2014.en.mindrecord", "test_dataset": "/home/workspace/dataset_menu/newstest2014.en.mindrecord",

@ -27,8 +27,6 @@ from src.dataset.tokenizer import Tokenizer
parser = argparse.ArgumentParser(description='gnmt') parser = argparse.ArgumentParser(description='gnmt')
parser.add_argument("--config", type=str, required=True, parser.add_argument("--config", type=str, required=True,
help="model config json file path.") help="model config json file path.")
parser.add_argument("--dataset_schema_test", type=str, required=True,
help="dataset schema for evaluation.")
parser.add_argument("--test_dataset", type=str, required=True, parser.add_argument("--test_dataset", type=str, required=True,
help="test dataset address.") help="test dataset address.")
parser.add_argument("--existed_ckpt", type=str, required=True, parser.add_argument("--existed_ckpt", type=str, required=True,
@ -63,7 +61,6 @@ if __name__ == '__main__':
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
_check_args(args.config) _check_args(args.config)
_config = get_config(args.config) _config = get_config(args.config)
_config.dataset_schema = args.dataset_schema_test
_config.test_dataset = args.test_dataset _config.test_dataset = args.test_dataset
_config.existed_ckpt = args.existed_ckpt _config.existed_ckpt = args.existed_ckpt
result = infer(_config) result = infer(_config)

@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export checkpoint file into air models"""
import argparse
import numpy as np
from mindspore import Tensor, context, Parameter
from mindspore.common import dtype as mstype
from mindspore.train.serialization import export
from config import GNMTConfig
from src.gnmt_model.gnmt import GNMT
from src.gnmt_model.gnmt_for_infer import GNMTInferCell
from src.utils import zero_weight
from src.utils.load_weights import load_infer_weights
parser = argparse.ArgumentParser(description="gnmt_v2 export")
parser.add_argument("--file_name", type=str, default="gnmt_v2", help="output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
parser.add_argument('--infer_config', type=str, required=True, help='gnmt_v2 config file')
parser.add_argument("--existed_ckpt", type=str, required=True, help="existed checkpoint address.")
parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file')
parser.add_argument("--bpe_codes", type=str, required=True, help="bpe codes to use.")
args = parser.parse_args()
context.set_context(
mode=context.GRAPH_MODE,
save_graphs=False,
device_target="Ascend",
reserve_class_name_in_scope=False)
def get_config(config_file):
tfm_config = GNMTConfig.from_json_file(config_file)
tfm_config.compute_type = mstype.float16
tfm_config.dtype = mstype.float32
return tfm_config
if __name__ == '__main__':
config = get_config(args.infer_config)
config.existed_ckpt = args.existed_ckpt
vocab = args.vocab_file
bpe_codes = args.bpe_codes
tfm_model = GNMT(config=config,
is_training=False,
use_one_hot_embeddings=False)
params = tfm_model.trainable_params()
weights = load_infer_weights(config)
for param in params:
value = param.data
name = param.name
if name not in weights:
raise ValueError(f"{name} is not found in weights.")
with open("weight_after_deal.txt", "a+") as f:
weights_name = name
f.write(weights_name)
f.write("\n")
if isinstance(value, Tensor):
print(name, value.asnumpy().shape)
if weights_name in weights:
assert weights_name in weights
if isinstance(weights[weights_name], Parameter):
if param.data.dtype == "Float32":
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float32))
elif param.data.dtype == "Float16":
param.set_data(Tensor(weights[weights_name].data.asnumpy(), mstype.float16))
elif isinstance(weights[weights_name], Tensor):
param.set_data(Tensor(weights[weights_name].asnumpy(), config.dtype))
elif isinstance(weights[weights_name], np.ndarray):
param.set_data(Tensor(weights[weights_name], config.dtype))
else:
param.set_data(weights[weights_name])
else:
print("weight not found in checkpoint: " + weights_name)
param.set_data(zero_weight(value.asnumpy().shape))
f.close()
print(" | Load weights successfully.")
tfm_infer = GNMTInferCell(tfm_model)
tfm_infer.set_train(False)
source_ids = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32))
source_mask = Tensor(np.ones((config.batch_size, config.seq_length)).astype(np.int32))
export(tfm_infer, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format)

@ -1,6 +1,4 @@
nltk
jieba
numpy numpy
subword-nmt==0.3.7 subword-nmt==0.3.7
sacrebleu==1.2.10 sacrebleu==1.4.14
sacremoses==0.0.19 sacremoses==0.0.35

@ -16,18 +16,16 @@
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the script as: " echo "Please run the script as: "
echo "sh run_distributed_train_ascend.sh RANK_TABLE_ADDR DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET" echo "sh run_distributed_train_ascend.sh RANK_TABLE_ADDR PRE_TRAIN_DATASET"
echo "for example:" echo "for example:"
echo "sh run_distributed_train_ascend.sh \ echo "sh run_distributed_train_ascend.sh \
/home/workspace/rank_table_8p.json \ /home/workspace/rank_table_8p.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord" /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"
echo "It is better to use absolute path." echo "It is better to use absolute path."
echo "==============================================================================================================" echo "=============================================================================================================="
RANK_TABLE_ADDR=$1 RANK_TABLE_ADDR=$1
DATASET_SCHEMA_TRAIN=$2 PRE_TRAIN_DATASET=$2
PRE_TRAIN_DATASET=$3
current_exec_path=$(pwd) current_exec_path=$(pwd)
echo ${current_exec_path} echo ${current_exec_path}
@ -51,7 +49,6 @@ do
export DEVICE_ID=$i export DEVICE_ID=$i
python ../../train.py \ python ../../train.py \
--config=${current_exec_path}/device${i}/config/config.json \ --config=${current_exec_path}/device${i}/config/config.json \
--dataset_schema_train=$DATASET_SCHEMA_TRAIN \
--pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network${i}.log 2>&1 & --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network${i}.log 2>&1 &
cd ${current_exec_path} || exit cd ${current_exec_path} || exit
done done

@ -16,11 +16,10 @@
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the script as: " echo "Please run the script as: "
echo "sh run_standalone_eval_ascend.sh DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \ echo "sh run_standalone_eval_ascend.sh TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET" VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
echo "for example:" echo "for example:"
echo "sh run_standalone_eval_ascend.sh \ echo "sh run_standalone_eval_ascend.sh \
/home/workspace/dataset_menu/newstest2014.en.json \
/home/workspace/dataset_menu/newstest2014.en.mindrecord \ /home/workspace/dataset_menu/newstest2014.en.mindrecord \
/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \ /home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
/home/workspace/wmt16_de_en/vocab.bpe.32000 \ /home/workspace/wmt16_de_en/vocab.bpe.32000 \
@ -29,19 +28,16 @@ echo "sh run_standalone_eval_ascend.sh \
echo "It is better to use absolute path." echo "It is better to use absolute path."
echo "==============================================================================================================" echo "=============================================================================================================="
DATASET_SCHEMA_TEST=$1 TEST_DATASET=$1
TEST_DATASET=$2 EXISTED_CKPT_PATH=$2
EXISTED_CKPT_PATH=$3 VOCAB_ADDR=$3
VOCAB_ADDR=$4 BPE_CODE_ADDR=$4
BPE_CODE_ADDR=$5 TEST_TARGET=$5
TEST_TARGET=$6
current_exec_path=$(pwd) current_exec_path=$(pwd)
echo ${current_exec_path} echo ${current_exec_path}
export DEVICE_NUM=1
export RANK_ID=0
export RANK_SIZE=1
export GLOG_v=2 export GLOG_v=2
if [ -d "eval" ]; if [ -d "eval" ];
@ -57,7 +53,6 @@ echo "start for evaluation"
env > env.log env > env.log
python eval.py \ python eval.py \
--config=${current_exec_path}/eval/config/config_test.json \ --config=${current_exec_path}/eval/config/config_test.json \
--dataset_schema_test=$DATASET_SCHEMA_TEST \
--test_dataset=$TEST_DATASET \ --test_dataset=$TEST_DATASET \
--existed_ckpt=$EXISTED_CKPT_PATH \ --existed_ckpt=$EXISTED_CKPT_PATH \
--vocab=$VOCAB_ADDR \ --vocab=$VOCAB_ADDR \

@ -16,21 +16,17 @@
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the script as: " echo "Please run the script as: "
echo "sh run_standalone_train_ascend.sh DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET" echo "sh run_standalone_train_ascend.sh PRE_TRAIN_DATASET"
echo "for example:" echo "for example:"
echo "sh run_standalone_train_ascend.sh \ echo "sh run_standalone_train_ascend.sh \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord" /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord"
echo "It is better to use absolute path." echo "It is better to use absolute path."
echo "==============================================================================================================" echo "=============================================================================================================="
DATASET_SCHEMA_TRAIN=$1 PRE_TRAIN_DATASET=$1
PRE_TRAIN_DATASET=$2
export DEVICE_NUM=1
export RANK_ID=0
export RANK_SIZE=1
export GLOG_v=2 export GLOG_v=2
current_exec_path=$(pwd) current_exec_path=$(pwd)
echo ${current_exec_path} echo ${current_exec_path}
if [ -d "train" ]; if [ -d "train" ];
@ -46,6 +42,5 @@ echo "start for training"
env > env.log env > env.log
python train.py \ python train.py \
--config=${current_exec_path}/train/config/config.json \ --config=${current_exec_path}/train/config/config.json \
--dataset_schema_train=$DATASET_SCHEMA_TRAIN \
--pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network.log 2>&1 & --pre_train_dataset=$PRE_TRAIN_DATASET > log_gnmt_network.log 2>&1 &
cd .. cd ..

@ -13,13 +13,12 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Dataset loader to feed into model.""" """Dataset loader to feed into model."""
import os
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC import mindspore.dataset.transforms.c_transforms as deC
def _load_dataset(input_files, schema_file, batch_size, sink_mode=False, def _load_dataset(input_files, batch_size, sink_mode=False,
rank_size=1, rank_id=0, shuffle=True, drop_remainder=True, rank_size=1, rank_id=0, shuffle=True, drop_remainder=True,
is_translate=False): is_translate=False):
""" """
@ -27,7 +26,6 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
Args: Args:
input_files (list): Data files. input_files (list): Data files.
schema_file (str): Schema file path.
batch_size (int): Batch size. batch_size (int): Batch size.
sink_mode (bool): Whether enable sink mode. sink_mode (bool): Whether enable sink mode.
rank_size (int): Rank size. rank_size (int): Rank size.
@ -42,12 +40,6 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
if not input_files: if not input_files:
raise FileNotFoundError("Require at least one dataset.") raise FileNotFoundError("Require at least one dataset.")
if not (schema_file and
os.path.exists(schema_file)
and os.path.isfile(schema_file)
and os.path.basename(schema_file).endswith(".json")):
raise FileNotFoundError("`dataset_schema` must be a existed json file.")
if not isinstance(sink_mode, bool): if not isinstance(sink_mode, bool):
raise ValueError("`sink` must be type of bool.") raise ValueError("`sink` must be type of bool.")
@ -116,14 +108,13 @@ def _load_dataset(input_files, schema_file, batch_size, sink_mode=False,
return ds return ds
def load_dataset(data_files: list, schema: str, batch_size: int, sink_mode: bool, def load_dataset(data_files: list, batch_size: int, sink_mode: bool,
rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False): rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):
""" """
Load dataset. Load dataset.
Args: Args:
data_files (list): Data files. data_files (list): Data files.
schema (str): Schema file path.
batch_size (int): Batch size. batch_size (int): Batch size.
sink_mode (bool): Whether enable sink mode. sink_mode (bool): Whether enable sink mode.
rank_size (int): Rank size. rank_size (int): Rank size.
@ -133,5 +124,5 @@ def load_dataset(data_files: list, schema: str, batch_size: int, sink_mode: bool
Returns: Returns:
Dataset, dataset instance. Dataset, dataset instance.
""" """
return _load_dataset(data_files, schema, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle, return _load_dataset(data_files, batch_size, sink_mode, rank_size, rank_id, shuffle=shuffle,
drop_remainder=drop_remainder, is_translate=is_translate) drop_remainder=drop_remainder, is_translate=is_translate)

@ -38,7 +38,7 @@ class BahdanauAttention(nn.Cell):
initializer_range: range for uniform initializer parameters. initializer_range: range for uniform initializer parameters.
Returns: Returns:
Tensor, shape (N, T, D). Tensor, shape (t_q_length, N, D).
""" """
def __init__(self, def __init__(self,
@ -93,108 +93,107 @@ class BahdanauAttention(nn.Cell):
Construct attention block. Construct attention block.
Args: Args:
query (Tensor): Shape (t_q, N, D). query (Tensor): Shape (t_q_length, N, D).
keys (Tensor): Shape (t_k, N, D). keys (Tensor): Shape (t_k_length, N, D).
attention_mask: Shape(N, t_k). attention_mask: Shape(N, t_k_length).
Returns: Returns:
Tensor, shape (N, t_q, D). Tensor, shape (t_q_length, N, D).
""" """
# (t_k, N, D) -> (N, t_k, D). # (t_k_length, N, D) -> (N, t_k_length, D).
keys = self.transpose(keys, self.transpose_orders) keys = self.transpose(keys, self.transpose_orders)
# (t_q, N, D) -> (N, t_q, D). # (t_q_length, N, D) -> (N, t_q_length, D).
query = self.transpose(query, self.transpose_orders) query_trans = self.transpose(query, self.transpose_orders)
query_shape = self.shape_op(query) query_shape = self.shape_op(query_trans)
b = query_shape[0] batch_size = query_shape[0]
t_q = query_shape[1] t_q_length = query_shape[1]
t_k = self.shape_op(keys)[1] t_k_length = self.shape_op(keys)[1]
# (N, t_q, D) # (N, t_q_length, D)
query = self.reshape(query, (b * t_q, self.query_size)) query_trans = self.reshape(query_trans, (batch_size * t_q_length, self.query_size))
if self.is_training: if self.is_training:
query = self.cast(query, mstype.float16) query_trans = self.cast(query_trans, mstype.float16)
processed_query = self.linear_q(query) processed_query = self.linear_q(query_trans)
if self.is_trining: if self.is_trining:
processed_query = self.cast(processed_query, mstype.float32) processed_query = self.cast(processed_query, mstype.float32)
processed_query = self.reshape(processed_query, (b, t_q, self.num_units)) processed_query = self.reshape(processed_query, (batch_size, t_q_length, self.num_units))
# (N, t_k, D) # (N, t_k_length, D)
keys = self.reshape(keys, (b * t_k, self.key_size)) keys = self.reshape(keys, (batch_size * t_k_length, self.key_size))
if self.is_training: if self.is_training:
keys = self.cast(keys, mstype.float16) keys = self.cast(keys, mstype.float16)
processed_key = self.linear_k(keys) processed_key = self.linear_k(keys)
if self.is_trining: if self.is_trining:
processed_key = self.cast(processed_key, mstype.float32) processed_key = self.cast(processed_key, mstype.float32)
processed_key = self.reshape(processed_key, (b, t_k, self.num_units)) processed_key = self.reshape(processed_key, (batch_size, t_k_length, self.num_units))
# scores: (N T_q T_k) # scores: (N, t_q_length, t_k_length)
scores = self.calc_score(processed_query, processed_key) scores = self.obtain_score(processed_query, processed_key)
# attention_mask: (N, T_k) # attention_mask: (N, t_k_length)
mask = attention_mask mask = attention_mask
# [N 1]
if mask is not None: if mask is not None:
mask = 1.0 - mask mask = 1.0 - mask
mask = self.tile(self.expand(mask, 1), (1, t_q, 1)) mask = self.tile(self.expand(mask, 1), (1, t_q_length, 1))
scores += mask * (-INF) scores += mask * (-INF)
# [b, t_q, t_k] # [batch_size, t_q_length, t_k_length]
scores_normalized = self.softmax(scores) scores_softmax = self.softmax(scores)
keys = self.reshape(keys, (b, t_k, self.key_size)) keys = self.reshape(keys, (batch_size, t_k_length, self.key_size))
if self.is_training: if self.is_training:
keys = self.cast(keys, mstype.float16) keys = self.cast(keys, mstype.float16)
scores_normalized_fp16 = self.cast(scores_normalized, mstype.float16) scores_softmax_fp16 = self.cast(scores_softmax, mstype.float16)
else: else:
scores_normalized_fp16 = scores_normalized scores_softmax_fp16 = scores_softmax
# (b, t_q, n) # (b, t_q_length, D)
context_attention = self.batchMatmul(scores_normalized_fp16, keys) context_attention = self.batchMatmul(scores_softmax_fp16, keys)
# [t_q,b,D] # [t_q_length, b, D]
context_attention = self.transpose(context_attention, self.transpose_orders) context_attention = self.transpose(context_attention, self.transpose_orders)
if self.is_training: if self.is_training:
context_attention = self.cast(context_attention, mstype.float32) context_attention = self.cast(context_attention, mstype.float32)
return context_attention, scores_normalized return context_attention, scores_softmax
def calc_score(self, att_query, att_keys): def obtain_score(self, attention_q, attention_k):
""" """
Calculate Bahdanau score Calculate Bahdanau score
Args: Args:
att_query: (N, T_q, D). attention_q: (batch_size, t_q_length, D).
att_keys: (N, T_k, D). attention_k: (batch_size, t_k_length, D).
returns: returns:
scores: (N, T_q, T_k). scores: (batch_size, t_q_length, t_k_length).
""" """
b, t_k, n = self.shape_op(att_keys) batch_size, t_k_length, D = self.shape_op(attention_k)
t_q = self.shape_op(att_query)[1] t_q_length = self.shape_op(attention_q)[1]
# (b, t_q, t_k, n) # (batch_size, t_q_length, t_k_length, n)
att_query = self.tile(self.expand(att_query, 2), (1, 1, t_k, 1)) attention_q = self.tile(self.expand(attention_q, 2), (1, 1, t_k_length, 1))
att_keys = self.tile(self.expand(att_keys, 1), (1, t_q, 1, 1)) attention_k = self.tile(self.expand(attention_k, 1), (1, t_q_length, 1, 1))
# (b, t_q, t_k, n) # (batch_size, t_q_length, t_k_length, n)
sum_qk = att_query + att_keys sum_qk_add = attention_q + attention_k
if self.normalize: if self.normalize:
# (b, t_q, t_k, n) # (batch_size, t_q_length, t_k_length, n)
sum_qk = sum_qk + self.normalize_bias sum_qk_add = sum_qk_add + self.normalize_bias
linear_att = self.linear_att / self.norm(self.linear_att) linear_att_norm = self.linear_att / self.norm(self.linear_att)
linear_att = self.cast(linear_att, mstype.float32) linear_att_norm = self.cast(linear_att_norm, mstype.float32)
linear_att = self.mul(linear_att, self.normalize_scalar) linear_att_norm = self.mul(linear_att_norm, self.normalize_scalar)
else: else:
linear_att = self.linear_att linear_att_norm = self.linear_att
linear_att = self.expand(linear_att, -1) linear_att_norm = self.expand(linear_att_norm, -1)
sum_qk = self.reshape(sum_qk, (-1, n)) sum_qk_add = self.reshape(sum_qk_add, (-1, D))
tanh_sum_qk = self.tanh(sum_qk) tanh_sum_qk = self.tanh(sum_qk_add)
if self.is_training: if self.is_training:
linear_att = self.cast(linear_att, mstype.float16) linear_att_norm = self.cast(linear_att_norm, mstype.float16)
tanh_sum_qk = self.cast(tanh_sum_qk, mstype.float16) tanh_sum_qk = self.cast(tanh_sum_qk, mstype.float16)
out = self.matmul(tanh_sum_qk, linear_att) scores_out = self.matmul(tanh_sum_qk, linear_att_norm)
# (b, t_q, t_k) # (N, t_q_length, t_k_length)
out = self.reshape(out, (b, t_q, t_k)) scores_out = self.reshape(scores_out, (batch_size, t_q_length, t_k_length))
if self.is_training: if self.is_training:
out = self.cast(out, mstype.float32) scores_out = self.cast(scores_out, mstype.float32)
return out return scores_out

@ -214,9 +214,8 @@ class BeamSearchDecoder(nn.Cell):
self.concat = P.Concat(axis=-1) self.concat = P.Concat(axis=-1)
self.gather_nd = P.GatherNd() self.gather_nd = P.GatherNd()
self.start = Tensor(0, dtype=mstype.int32)
self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32) self.start_ids = Tensor(np.full([batch_size * beam_width, 1], sos_id), mstype.int32)
self.init_seq = Tensor(np.full([batch_size, beam_width, self.max_decode_length], sos_id), mstype.int32) self.init_seq = Tensor(np.full([batch_size, beam_width, 1], sos_id), mstype.int32)
init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1]) init_scores = np.tile(np.array([[0.] + [-INF] * (beam_width - 1)]), [batch_size, 1])
self.init_scores = Tensor(init_scores, mstype.float32) self.init_scores = Tensor(init_scores, mstype.float32)
@ -260,7 +259,7 @@ class BeamSearchDecoder(nn.Cell):
self.sub = P.Sub() self.sub = P.Sub()
def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs, def one_step(self, cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
state_seq, state_length, idx=None, decoder_hidden_state=None, accu_attn_scores=None, state_seq, state_length, decoder_hidden_state=None, accu_attn_scores=None,
state_finished=None): state_finished=None):
""" """
Beam search one_step output. Beam search one_step output.
@ -270,7 +269,7 @@ class BeamSearchDecoder(nn.Cell):
enc_states (Tensor): with shape (batch_size * beam_width, T, D). enc_states (Tensor): with shape (batch_size * beam_width, T, D).
enc_attention_mask (Tensor): with shape (batch_size * beam_width, T). enc_attention_mask (Tensor): with shape (batch_size * beam_width, T).
state_log_probs (Tensor): with shape (batch_size, beam_width). state_log_probs (Tensor): with shape (batch_size, beam_width).
state_seq (Tensor): with shape (batch_size, beam_width, max_decoder_length). state_seq (Tensor): with shape (batch_size, beam_width, m).
state_length (Tensor): with shape (batch_size, beam_width). state_length (Tensor): with shape (batch_size, beam_width).
idx (Tensor): with shape (). idx (Tensor): with shape ().
decoder_hidden_state (Tensor): with shape (decoder_layer_num, 2, batch_size * beam_width, D). decoder_hidden_state (Tensor): with shape (decoder_layer_num, 2, batch_size * beam_width, D).
@ -360,10 +359,7 @@ class BeamSearchDecoder(nn.Cell):
self.hidden_size)) self.hidden_size))
# update state_seq # update state_seq
state_seq_new = self.cast(seq, mstype.float32) state_seq = self.concat((seq, self.expand(word_indices, -1)))
word_indices_fp32 = self.cast(word_indices, mstype.float32)
state_seq_new[:, :, idx] = word_indices_fp32
state_seq = self.cast(state_seq_new, mstype.int32)
cur_input_ids = self.reshape(word_indices, (-1, 1)) cur_input_ids = self.reshape(word_indices, (-1, 1))
state_log_probs = topk_scores state_log_probs = topk_scores
@ -392,15 +388,11 @@ class BeamSearchDecoder(nn.Cell):
decoder_hidden_state = self.decoder_hidden_state decoder_hidden_state = self.decoder_hidden_state
accu_attn_scores = self.accu_attn_scores accu_attn_scores = self.accu_attn_scores
idx = self.start + 1 for _ in range(self.max_decode_length + 1):
ends = self.start + self.max_decode_length + 1
while idx < ends:
cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \ cur_input_ids, state_log_probs, state_seq, state_length, decoder_hidden_state, accu_attn_scores, \
state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs, state_finished = self.one_step(cur_input_ids, enc_states, enc_attention_mask, state_log_probs,
state_seq, state_length, idx, decoder_hidden_state, accu_attn_scores, state_seq, state_length, decoder_hidden_state, accu_attn_scores,
state_finished) state_finished)
idx = idx + 1
# add length penalty scores # add length penalty scores
penalty_len = self.length_penalty(state_length) penalty_len = self.length_penalty(state_length)
# return penalty_len # return penalty_len
@ -416,6 +408,6 @@ class BeamSearchDecoder(nn.Cell):
gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1))) gather_indices = self.concat((self.expand(self.batch_ids, -1), self.expand(top_beam_indices, -1)))
# sort sequence and attention scores # sort sequence and attention scores
predicted_ids = self.gather_nd(state_seq, gather_indices) predicted_ids = self.gather_nd(state_seq, gather_indices)
predicted_ids = predicted_ids[:, 0:1, :self.max_decode_length] predicted_ids = predicted_ids[:, 0:1, 1:(self.max_decode_length + 1)]
return predicted_ids return predicted_ids

@ -187,7 +187,6 @@ def infer(config):
list, result with list, result with
""" """
eval_dataset = load_dataset(data_files=config.test_dataset, eval_dataset = load_dataset(data_files=config.test_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size, batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
drop_remainder=False, drop_remainder=False,

@ -40,7 +40,6 @@ from src.utils.optimizer import Adam
parser = argparse.ArgumentParser(description='GNMT train entry point.') parser = argparse.ArgumentParser(description='GNMT train entry point.')
parser.add_argument("--config", type=str, required=True, help="model config json file path.") parser.add_argument("--config", type=str, required=True, help="model config json file path.")
parser.add_argument("--dataset_schema_train", type=str, required=True, help="dataset schema for train.")
parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.") parser.add_argument("--pre_train_dataset", type=str, required=True, help="pre-train dataset address.")
device_id = os.getenv('DEVICE_ID', None) device_id = os.getenv('DEVICE_ID', None)
@ -273,21 +272,20 @@ def train_parallel(config: GNMTConfig):
pre_train_dataset = load_dataset( pre_train_dataset = load_dataset(
data_files=config.pre_train_dataset, data_files=config.pre_train_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size, batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
rank_size=MultiAscend.get_group_size(), rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank() rank_id=MultiAscend.get_rank()
) if config.pre_train_dataset else None ) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset( fine_tune_dataset = load_dataset(
data_files=config.fine_tune_dataset, schema=config.dataset_schema, data_files=config.fine_tune_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
rank_size=MultiAscend.get_group_size(), rank_size=MultiAscend.get_group_size(),
rank_id=MultiAscend.get_rank() rank_id=MultiAscend.get_rank()
) if config.fine_tune_dataset else None ) if config.fine_tune_dataset else None
test_dataset = load_dataset( test_dataset = load_dataset(
data_files=config.test_dataset, schema=config.dataset_schema, data_files=config.test_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode, sink_mode=config.dataset_sink_mode,
rank_size=MultiAscend.get_group_size(), rank_size=MultiAscend.get_group_size(),
@ -310,15 +308,12 @@ def train_single(config: GNMTConfig):
print(" | Starting training on single device.") print(" | Starting training on single device.")
pre_train_dataset = load_dataset(data_files=config.pre_train_dataset, pre_train_dataset = load_dataset(data_files=config.pre_train_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size, batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode) if config.pre_train_dataset else None sink_mode=config.dataset_sink_mode) if config.pre_train_dataset else None
fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset, fine_tune_dataset = load_dataset(data_files=config.fine_tune_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size, batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode) if config.fine_tune_dataset else None sink_mode=config.dataset_sink_mode) if config.fine_tune_dataset else None
test_dataset = load_dataset(data_files=config.test_dataset, test_dataset = load_dataset(data_files=config.test_dataset,
schema=config.dataset_schema,
batch_size=config.batch_size, batch_size=config.batch_size,
sink_mode=config.dataset_sink_mode) if config.test_dataset else None sink_mode=config.dataset_sink_mode) if config.test_dataset else None
@ -341,7 +336,6 @@ if __name__ == '__main__':
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
_check_args(args.config) _check_args(args.config)
_config = get_config(args.config) _config = get_config(args.config)
_config.dataset_schema = args.dataset_schema_train
_config.pre_train_dataset = args.pre_train_dataset _config.pre_train_dataset = args.pre_train_dataset
set_seed(_config.random_seed) set_seed(_config.random_seed)
if _rank_size is not None and int(_rank_size) > 1: if _rank_size is not None and int(_rank_size) > 1:

@ -31,15 +31,11 @@ parser = argparse.ArgumentParser(description='GNMT train and eval.')
# train # train
parser.add_argument("--config_train", type=str, required=True, parser.add_argument("--config_train", type=str, required=True,
help="model config json file path.") help="model config json file path.")
parser.add_argument("--dataset_schema_train", type=str, required=True,
help="dataset schema for train.")
parser.add_argument("--pre_train_dataset", type=str, required=True, parser.add_argument("--pre_train_dataset", type=str, required=True,
help="pre-train dataset address.") help="pre-train dataset address.")
# eval # eval
parser.add_argument("--config_test", type=str, required=True, parser.add_argument("--config_test", type=str, required=True,
help="model config json file path.") help="model config json file path.")
parser.add_argument("--dataset_schema_test", type=str, required=True,
help="dataset schema for evaluation.")
parser.add_argument("--test_dataset", type=str, required=True, parser.add_argument("--test_dataset", type=str, required=True,
help="test dataset address.") help="test dataset address.")
parser.add_argument("--existed_ckpt", type=str, required=True, parser.add_argument("--existed_ckpt", type=str, required=True,
@ -77,7 +73,6 @@ if __name__ == '__main__':
# train # train
_check_args(args.config_train) _check_args(args.config_train)
_config_train = get_config(args.config_train) _config_train = get_config(args.config_train)
_config_train.dataset_schema = args.dataset_schema_train
_config_train.pre_train_dataset = args.pre_train_dataset _config_train.pre_train_dataset = args.pre_train_dataset
set_seed(_config_train.random_seed) set_seed(_config_train.random_seed)
assert _rank_size is not None and int(_rank_size) > 1 assert _rank_size is not None and int(_rank_size) > 1
@ -86,7 +81,6 @@ if __name__ == '__main__':
# eval # eval
_check_args(args.config_test) _check_args(args.config_test)
_config_test = get_config(args.config_test) _config_test = get_config(args.config_test)
_config_test.dataset_schema = args.dataset_schema_test
_config_test.test_dataset = args.test_dataset _config_test.test_dataset = args.test_dataset
_config_test.existed_ckpt = args.existed_ckpt _config_test.existed_ckpt = args.existed_ckpt
result = infer(_config_test) result = infer(_config_test)

@ -16,19 +16,15 @@
echo "==============================================================================================================" echo "=============================================================================================================="
echo "Please run the scipt as: " echo "Please run the scipt as: "
echo "sh run_distributed_train_ascend.sh \ echo "sh test_gnmt_v2.sh \
GNMT_ADDR RANK_TABLE_ADDR \ GNMT_ADDR RANK_TABLE_ADDR PRE_TRAIN_DATASET TEST_DATASET EXISTED_CKPT_PATH \
DATASET_SCHEMA_TRAIN PRE_TRAIN_DATASET \
DATASET_SCHEMA_TEST TEST_DATASET EXISTED_CKPT_PATH \
VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET" VOCAB_ADDR BPE_CODE_ADDR TEST_TARGET"
echo "for example:" echo "for example:"
echo "sh run_distributed_train_ascend.sh \ echo "sh test_gnmt_v2.sh \
/home/workspace/gnmt_v2 \ /home/workspace/gnmt_v2 \
/home/workspace/rank_table_8p.json \ /home/workspace/rank_table_8p.json \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json \ /home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.mindrecord \
/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001 \ /home/workspace/dataset_menu/newstest2014.en.mindrecord \
/home/workspace/dataset_menu/newstest2014.en.json \
/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001 \
/home/workspace/gnmt_v2/gnmt-6_3452.ckpt \ /home/workspace/gnmt_v2/gnmt-6_3452.ckpt \
/home/workspace/wmt16_de_en/vocab.bpe.32000 \ /home/workspace/wmt16_de_en/vocab.bpe.32000 \
/home/workspace/wmt16_de_en/bpe.32000 \ /home/workspace/wmt16_de_en/bpe.32000 \
@ -39,15 +35,13 @@ echo "==========================================================================
GNMT_ADDR=$1 GNMT_ADDR=$1
RANK_TABLE_ADDR=$2 RANK_TABLE_ADDR=$2
# train dataset addr # train dataset addr
DATASET_SCHEMA_TRAIN=$3 PRE_TRAIN_DATASET=$3
PRE_TRAIN_DATASET=$4
# eval dataset addr # eval dataset addr
DATASET_SCHEMA_TEST=$5 TEST_DATASET=$4
TEST_DATASET=$6 EXISTED_CKPT_PATH=$5
EXISTED_CKPT_PATH=$7 VOCAB_ADDR=$6
VOCAB_ADDR=$8 BPE_CODE_ADDR=$7
BPE_CODE_ADDR=$9 TEST_TARGET=$8
TEST_TARGET=${10}
current_exec_path=$(pwd) current_exec_path=$(pwd)
echo ${current_exec_path} echo ${current_exec_path}
@ -72,10 +66,8 @@ do
export DEVICE_ID=$i export DEVICE_ID=$i
python test_gnmt_v2.py \ python test_gnmt_v2.py \
--config_train=${GNMT_ADDR}/config/config.json \ --config_train=${GNMT_ADDR}/config/config.json \
--dataset_schema_train=$DATASET_SCHEMA_TRAIN \
--pre_train_dataset=$PRE_TRAIN_DATASET \ --pre_train_dataset=$PRE_TRAIN_DATASET \
--config_test=${GNMT_ADDR}/config/config_test.json \ --config_test=${GNMT_ADDR}/config/config_test.json \
--dataset_schema_test=$DATASET_SCHEMA_TEST \
--test_dataset=$TEST_DATASET \ --test_dataset=$TEST_DATASET \
--existed_ckpt=$EXISTED_CKPT_PATH \ --existed_ckpt=$EXISTED_CKPT_PATH \
--vocab=$VOCAB_ADDR \ --vocab=$VOCAB_ADDR \

Loading…
Cancel
Save