!8307 GNMT v2 network
From: @gaojing22 Reviewed-by: @yingjy, @c_34 Signed-off-by: @yingjy (pull/8307/MERGE)
commit
b7fc189c2b
File diff suppressed because it is too large
@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMTv2 model configuration."""
from .config import GNMTConfig

__all__ = [
    "GNMTConfig"
]

@@ -0,0 +1,55 @@
{
  "training_platform": {
    "modelarts": false
  },
  "dataset_config": {
    "random_seed": 50,
    "epochs": 6,
    "batch_size": 128,
    "dataset_schema": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
    "pre_train_dataset": "/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001",
    "fine_tune_dataset": null,
    "test_dataset": null,
    "valid_dataset": null,
    "dataset_sink_mode": true,
    "dataset_sink_step": 2
  },
  "model_config": {
    "seq_length": 51,
    "vocab_size": 32320,
    "hidden_size": 1024,
    "num_hidden_layers": 4,
    "intermediate_size": 4096,
    "hidden_dropout_prob": 0.2,
    "attention_dropout_prob": 0.2,
    "initializer_range": 0.1,
    "label_smoothing": 0.1,
    "beam_width": 2,
    "length_penalty_weight": 0.6,
    "max_decode_length": 50
  },
  "loss_scale_config": {
    "init_loss_scale": 65536,
    "loss_scale_factor": 2,
    "scale_window": 1000
  },
  "learn_rate_config": {
    "optimizer": "adam",
    "lr": 2e-3,
    "lr_scheduler": "WarmupMultiStepLR",
    "lr_scheduler_power": 0.5,
    "warmup_lr_remain_steps": 0.666,
    "warmup_lr_decay_interval": -1,
    "decay_steps": 4,
    "decay_start_step": -1,
    "warmup_steps": 200,
    "min_lr": 1e-6
  },
  "checkpoint_options": {
    "existed_ckpt": "",
    "save_ckpt_steps": 3452,
    "keep_ckpt_max": 6,
    "ckpt_prefix": "gnmt",
    "ckpt_path": "text_translation"
  }
}

@@ -0,0 +1,228 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration class for GNMT."""
import os
import json
import copy
from typing import List

import mindspore.common.dtype as mstype


def _is_dataset_file(file: str):
    return "tfrecord" in file.lower() or "mindrecord" in file.lower()


def _get_files_from_dir(folder: str):
    _files = []
    for file in os.listdir(folder):
        if _is_dataset_file(file):
            _files.append(os.path.join(folder, file))
    return _files


def get_source_list(folder: str) -> List:
    """
    Get file list from a folder.

    Returns:
        list, file list.
    """
    _list = []
    if not folder:
        return _list

    if os.path.isdir(folder):
        _list = _get_files_from_dir(folder)
    else:
        if _is_dataset_file(folder):
            _list.append(folder)
    return _list


PARAM_NODES = {"dataset_config",
               "training_platform",
               "model_config",
               "loss_scale_config",
               "learn_rate_config",
               "checkpoint_options"}


class GNMTConfig:
    """
    Configuration for `GNMT`.

    Args:
        random_seed (int): Random seed.
        batch_size (int): Batch size of input dataset.
        epochs (int): Epoch number.
        dataset_sink_mode (bool): Whether to enable dataset sink mode.
        dataset_sink_step (int): Dataset sink step.
        lr_scheduler (str): Learning rate scheduler to use; only "ISR" is supported now.
        lr (float): Initial learning rate.
        min_lr (float): Minimum learning rate.
        decay_start_step (int): Step to decay.
        warmup_steps (int): Warm up steps.
        dataset_schema (str): Path of dataset schema file.
        pre_train_dataset (str): Path of pre-training dataset file or folder.
        fine_tune_dataset (str): Path of fine-tune dataset file or folder.
        test_dataset (str): Path of test dataset file or folder.
        valid_dataset (str): Path of validation dataset file or folder.
        ckpt_path (str): Checkpoints save path.
        save_ckpt_steps (int): Interval of saving ckpt.
        ckpt_prefix (str): Prefix of ckpt file.
        keep_ckpt_max (int): Maximum number of ckpt files to keep.
        seq_length (int): Length of input sequence. Default: 64.
        vocab_size (int): The shape of each embedding vector. Default: 46192.
        hidden_size (int): Size of embedding, attention, dim. Default: 512.
        num_hidden_layers (int): Encoder, Decoder layers.
        intermediate_size (int): Size of intermediate layer in the Transformer
            encoder/decoder cell. Default: 4096.
        hidden_act (str): Activation function used in the Transformer encoder/decoder
            cell. Default: "relu".
        init_loss_scale (int): Initialized loss scale.
        loss_scale_factor (int): Loss scale factor.
        scale_window (int): Window size of loss scale.
        beam_width (int): Beam width for beam search in inferring. Default: 4.
        length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
        label_smoothing (float): Label smoothing setting. Default: 0.1.
        input_mask_from_dataset (bool): Specifies whether to use the input mask loaded from
            the dataset. Default: True.
        save_graphs (bool): Whether to save graphs; set to True if MindInsight is needed.
        dtype (mstype): Data type of the input. Default: mstype.float32.
        max_decode_length (int): Max decode length for inferring. Default: 64.
        hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
        attention_dropout_prob (float): The dropout probability for
            Multi-head Self-Attention. Default: 0.1.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """

    def __init__(self,
                 modelarts=False, random_seed=74,
                 epochs=6, batch_size=64,
                 dataset_schema: str = None,
                 pre_train_dataset: str = None,
                 fine_tune_dataset: str = None,
                 test_dataset: str = None,
                 valid_dataset: str = None,
                 dataset_sink_mode=True, dataset_sink_step=1,
                 seq_length=51, vocab_size=32320, hidden_size=1024,
                 num_hidden_layers=4, intermediate_size=4096,
                 hidden_act="tanh",
                 hidden_dropout_prob=0.2, attention_dropout_prob=0.2,
                 initializer_range=0.1,
                 label_smoothing=0.1,
                 beam_width=5,
                 length_penalty_weight=1.0,
                 max_decode_length=50,
                 input_mask_from_dataset=False,
                 init_loss_scale=2 ** 10,
                 loss_scale_factor=2, scale_window=128,
                 lr_scheduler="", optimizer="adam",
                 lr=1e-4, min_lr=1e-6,
                 decay_steps=4, lr_scheduler_power=1,
                 warmup_lr_remain_steps=0.666, warmup_lr_decay_interval=-1,
                 decay_start_step=-1, warmup_steps=200,
                 existed_ckpt="", save_ckpt_steps=2000, keep_ckpt_max=20,
                 ckpt_prefix="gnmt", ckpt_path: str = None,
                 save_step=10000,
                 save_graphs=False,
                 dtype=mstype.float32):

        self.save_graphs = save_graphs
        self.random_seed = random_seed
        self.modelarts = modelarts
        self.save_step = save_step
        self.dataset_schema = dataset_schema
        self.pre_train_dataset = get_source_list(pre_train_dataset)  # type: List[str]
        self.fine_tune_dataset = get_source_list(fine_tune_dataset)  # type: List[str]
        self.valid_dataset = get_source_list(valid_dataset)  # type: List[str]
        self.test_dataset = get_source_list(test_dataset)  # type: List[str]

        if not isinstance(epochs, int) or epochs < 0:
            raise ValueError("`epochs` must be a non-negative integer.")

        self.epochs = epochs
        self.dataset_sink_mode = dataset_sink_mode
        self.dataset_sink_step = dataset_sink_step

        self.ckpt_path = ckpt_path
        self.keep_ckpt_max = keep_ckpt_max
        self.save_ckpt_steps = save_ckpt_steps
        self.ckpt_prefix = ckpt_prefix
        self.existed_ckpt = existed_ckpt

        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob

        self.initializer_range = initializer_range
        self.label_smoothing = label_smoothing

        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        self.max_decode_length = max_decode_length
        self.input_mask_from_dataset = input_mask_from_dataset
        self.compute_type = mstype.float16
        self.dtype = dtype

        self.scale_window = scale_window
        self.loss_scale_factor = loss_scale_factor
        self.init_loss_scale = init_loss_scale

        self.optimizer = optimizer
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.min_lr = min_lr
        self.lr_scheduler_power = lr_scheduler_power
        self.warmup_lr_remain_steps = warmup_lr_remain_steps
        self.warmup_lr_decay_interval = warmup_lr_decay_interval
        self.decay_steps = decay_steps
        self.decay_start_step = decay_start_step
        self.warmup_steps = warmup_steps

        self.train_url = ""

    @classmethod
    def from_dict(cls, json_object: dict):
        """Constructs a `GNMTConfig` from a Python dictionary of parameters."""
        _params = {}
        for node in PARAM_NODES:
            for key in json_object[node]:
                _params[key] = json_object[node][key]
        return cls(**_params)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `GNMTConfig` from a json file of parameters."""
        with open(json_file, "r") as reader:
            return cls.from_dict(json.load(reader))

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"

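A minimal usage sketch, assuming the training JSON shown above is saved under the config folder (the shell scripts below pass /home/workspace/gnmt_v2/config/config.json): `from_dict` flattens the six top-level JSON nodes into one set of keyword arguments for the constructor, and fields absent from the JSON keep the constructor defaults.

from config import GNMTConfig

config = GNMTConfig.from_json_file("config/config.json")
# Values taken from the JSON above.
print(config.batch_size, config.hidden_size, config.lr_scheduler)
# hidden_act is not in the JSON, so the constructor default "tanh" is kept.
print(config.hidden_act)
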
@@ -0,0 +1,55 @@
{
  "training_platform": {
    "modelarts": false
  },
  "dataset_config": {
    "random_seed": 50,
    "epochs": 6,
    "batch_size": 128,
    "dataset_schema": "/home/workspace/dataset_menu/newstest2014.en.json",
    "pre_train_dataset": null,
    "fine_tune_dataset": null,
    "test_dataset": "/home/workspace/dataset_menu/newstest2014.en.tfrecord-001-of-001",
    "valid_dataset": null,
    "dataset_sink_mode": true,
    "dataset_sink_step": 2
  },
  "model_config": {
    "seq_length": 107,
    "vocab_size": 32320,
    "hidden_size": 1024,
    "num_hidden_layers": 4,
    "intermediate_size": 4096,
    "hidden_dropout_prob": 0.2,
    "attention_dropout_prob": 0.2,
    "initializer_range": 0.1,
    "label_smoothing": 0.1,
    "beam_width": 2,
    "length_penalty_weight": 0.6,
    "max_decode_length": 80
  },
  "loss_scale_config": {
    "init_loss_scale": 8192,
    "loss_scale_factor": 2,
    "scale_window": 128
  },
  "learn_rate_config": {
    "optimizer": "adam",
    "lr": 2e-3,
    "lr_scheduler": "WarmupMultiStepLR",
    "lr_scheduler_power": 0.5,
    "warmup_lr_remain_steps": 0.666,
    "warmup_lr_decay_interval": -1,
    "decay_steps": 4,
    "decay_start_step": -1,
    "warmup_steps": 200,
    "min_lr": 1e-6
  },
  "checkpoint_options": {
    "existed_ckpt": "/home/workspace/gnmt_v2/gnmt-6_3452.ckpt",
    "save_ckpt_steps": 3452,
    "keep_ckpt_max": 6,
    "ckpt_prefix": "gnmt",
    "ckpt_path": "text_translation"
  }
}

@@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create Dataset."""
import os
import argparse

from src.dataset.bi_data_loader import BiLingualDataLoader, TextDataLoader
from src.dataset.tokenizer import Tokenizer

parser = argparse.ArgumentParser(description='Generate dataset file.')
parser.add_argument("--src_folder", type=str, default="/home/workspace/wmt16_de_en", required=False,
                    help="Raw corpus folder.")
parser.add_argument("--output_folder", type=str, default="/home/workspace/dataset_menu",
                    required=False,
                    help="Dataset output path.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    if not os.path.exists(args.output_folder):
        os.makedirs(args.output_folder)
    dicts = []
    train_src_file = "train.tok.clean.bpe.32000.en"
    train_tgt_file = "train.tok.clean.bpe.32000.de"
    test_src_file = "newstest2014.en"
    test_tgt_file = "newstest2014.de"

    vocab = args.src_folder + "/vocab.bpe.32000"
    bpe_codes = args.src_folder + "/bpe.32000"
    pad_vocab = 8
    tokenizer = Tokenizer(vocab, bpe_codes, src_en='en', tgt_de='de', vocab_pad=pad_vocab)

    test = TextDataLoader(
        src_filepath=os.path.join(args.src_folder, test_src_file),
        tokenizer=tokenizer,
        source_max_sen_len=None,
        schema_address=args.output_folder + "/" + test_src_file + ".json"
    )
    print(f" | It's writing, please wait a moment.")
    test.write_to_tfrecord(
        path=os.path.join(
            args.output_folder,
            os.path.basename(test_src_file) + ".tfrecord"
        )
    )

    train = BiLingualDataLoader(
        src_filepath=os.path.join(args.src_folder, train_src_file),
        tgt_filepath=os.path.join(args.src_folder, train_tgt_file),
        tokenizer=tokenizer,
        source_max_sen_len=51,
        target_max_sen_len=50,
        schema_address=args.output_folder + "/" + train_src_file + ".json"
    )
    print(f" | It's writing, please wait a moment.")
    train.write_to_tfrecord(
        path=os.path.join(
            args.output_folder,
            os.path.basename(train_src_file) + ".tfrecord"
        )
    )

    print(f" | Vocabulary size: {tokenizer.vocab_size}.")

@@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation api."""
import argparse
import pickle

from mindspore.common import dtype as mstype

from config import GNMTConfig
from src.gnmt_model import infer
from src.gnmt_model.bleu_calculate import bleu_calculate
from src.dataset.tokenizer import Tokenizer

parser = argparse.ArgumentParser(description='gnmt')
parser.add_argument("--config", type=str, required=True,
                    help="model config json file path.")
parser.add_argument("--vocab", type=str, required=True,
                    help="Vocabulary to use.")
parser.add_argument("--bpe_codes", type=str, required=True,
                    help="bpe codes to use.")
parser.add_argument("--test_tgt", type=str, required=False,
                    default=None,
                    help="data file of the test target")
parser.add_argument("--output", type=str, required=False,
                    default="./output.npz",
                    help="result file path.")


def get_config(config):
    """Load the config json file and force float16 compute with float32 parameters."""
    config = GNMTConfig.from_json_file(config)
    config.compute_type = mstype.float16
    config.dtype = mstype.float32
    return config


if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    _config = get_config(args.config)
    result = infer(_config)

    with open(args.output, "wb") as f:
        pickle.dump(result, f, 1)

    result_npy_addr = args.output
    vocab = args.vocab
    bpe_codes = args.bpe_codes
    test_tgt = args.test_tgt
    tokenizer = Tokenizer(vocab, bpe_codes, 'en', 'de')
    scores = bleu_calculate(tokenizer, result_npy_addr, test_tgt)
    print(f"BLEU score is: {scores}")

@@ -0,0 +1,6 @@
nltk
jieba
numpy
subword-nmt==0.3.7
sacrebleu==1.2.10
sacremoses==0.0.19

@@ -0,0 +1,39 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
current_exec_path=$(pwd)
echo ${current_exec_path}

export RANK_TABLE_FILE=/home/workspace/rank_table_8p.json
export MINDSPORE_HCCL_CONFIG_PATH=/home/workspace/rank_table_8p.json

echo $RANK_TABLE_FILE
export RANK_SIZE=8

for((i=0;i<=7;i++));
do
    rm -rf ${current_exec_path}/device$i
    mkdir ${current_exec_path}/device$i
    cd ${current_exec_path}/device$i || exit
    cp ../../*.py .
    cp ../../*.sh .
    cp -r ../../src .
    cp -r ../../config .
    export RANK_ID=$i
    export DEVICE_ID=$i
    python ../../train.py --config /home/workspace/gnmt_v2/config/config.json > log_gnmt_network${i}.log 2>&1 &
    cd ${current_exec_path} || exit
done
cd ${current_exec_path} || exit

@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export DEVICE_ID=5
export RANK_ID=0
export RANK_SIZE=1

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cp -r ../config ./eval
cd ./eval || exit
echo "start eval for device $DEVICE_ID"
env > env.log
python eval.py --config /home/workspace/gnmt_v2/config/config_test.json --vocab /home/workspace/wmt16_de_en/vocab.bpe.32000 --bpe_codes /home/workspace/wmt16_de_en/bpe.32000 --test_tgt /home/workspace/wmt16_de_en/newstest2014.de >log_infer.log 2>&1 &
cd ..

@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_NUM=1
export DEVICE_ID=4
export RANK_ID=0
export RANK_SIZE=1

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp -r ../src ./train
cp -r ../config ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --config /home/workspace/gnmt_v2/config/config.json > log_gnmt_network.log 2>&1 &
cd ..

@@ -0,0 +1,29 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMTv2 Init."""
from .dataset import load_dataset
from .dataset import bi_data_loader
from .gnmt_model import GNMT, infer, GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
from .gnmt_model import LabelSmoothedCrossEntropyCriterion

__all__ = [
    "load_dataset",
    "bi_data_loader",
    "GNMT",
    "infer",
    "GNMTNetworkWithLoss",
    "GNMTTrainOneStepWithLossScaleCell",
    "LabelSmoothedCrossEntropyCriterion"
]

@@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset Init."""
from .bi_data_loader import BiLingualDataLoader, TextDataLoader
from .load_dataset import load_dataset
from .tokenizer import Tokenizer

__all__ = [
    "load_dataset",
    "BiLingualDataLoader",
    "TextDataLoader",
    "Tokenizer"
]

@@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class of data loader."""
import os
import collections
import numpy as np

from mindspore.mindrecord import FileWriter
from .schema import SCHEMA


class DataLoader:
    """Data loader for dataset."""
    _SCHEMA = SCHEMA

    def __init__(self):
        self._examples = []

    def _load(self):
        raise NotImplementedError

    def padding(self, sen, padding_idx, need_sentence_len=None, dtype=np.int64):
        """Padding <pad> to sentence."""
        if need_sentence_len is None:
            return None
        if sen.shape[0] > need_sentence_len:
            return None
        new_sen = np.array([padding_idx] * need_sentence_len, dtype=dtype)
        new_sen[:sen.shape[0]] = sen[:]
        return new_sen

    def write_to_mindrecord(self, path, shard_num=1, desc=""):
        """
        Write mindrecord file.

        Args:
            path (str): File path.
            shard_num (int): Shard num.
            desc (str): Description.
        """
        if not os.path.isabs(path):
            path = os.path.abspath(path)

        writer = FileWriter(file_name=path, shard_num=shard_num)
        writer.add_schema(self._SCHEMA, desc)
        if not self._examples:
            self._load()

        writer.write_raw_data(self._examples)
        writer.commit()
        print(f"| Wrote to {path}.")

    def write_to_tfrecord(self, path, shard_num=1):
        """
        Write to tfrecord.

        Args:
            path (str): Output file path.
            shard_num (int): Shard num.
        """
        import tensorflow as tf
        if not os.path.isabs(path):
            path = os.path.abspath(path)
        output_files = []
        for i in range(shard_num):
            output_file = path + "-%03d-of-%03d" % (i + 1, shard_num)
            output_files.append(output_file)
        # create writers
        writers = []
        for output_file in output_files:
            writers.append(tf.io.TFRecordWriter(output_file))

        if not self._examples:
            self._load()

        # create feature
        features = collections.OrderedDict()
        for example in self._examples:
            for key in example:
                features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=example[key].tolist()))
            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            for writer in writers:
                writer.write(tf_example.SerializeToString())
        for writer in writers:
            writer.close()
        for p in output_files:
            print(f" | Write to {p}.")

    def _add_example(self, example):
        self._examples.append(example)

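A minimal sketch of how a concrete loader is expected to plug into this base class: subclass DataLoader, fill self._examples inside _load() via _add_example(), then call one of the writers, which validates each example against SCHEMA. The class name, output path, and token ids below are invented for illustration only.

import numpy as np

from src.dataset.base import DataLoader


class ToyDataLoader(DataLoader):
    """Hypothetical loader that emits a single hand-written example."""

    def _load(self):
        # <s> tok tok <\s>, using the tokenizer's fixed indexes pad=0, bos=2, eos=3.
        src = np.array([2, 11, 12, 3], dtype=np.int64)
        example = {
            "src": self.padding(src, padding_idx=0, need_sentence_len=8),
            "src_padding": np.array([1, 1, 1, 1, 0, 0, 0, 0], dtype=np.int64),
            "src_length": np.array([4], dtype=np.int64),
            "prev_opt": np.zeros(8, dtype=np.int64),
            "target": np.zeros(8, dtype=np.int64),
            "tgt_padding": np.zeros(8, dtype=np.int64),
        }
        self._add_example(example)


# Writes ./toy.mindrecord with the six SCHEMA columns defined in schema.py.
ToyDataLoader().write_to_mindrecord("./toy.mindrecord")
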
@@ -0,0 +1,233 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bilingual data loader."""
import numpy as np

from .base import DataLoader
from .tokenizer import Tokenizer


class BiLingualDataLoader(DataLoader):
    """Loader for bilingual data."""

    def __init__(self,
                 src_filepath: str,
                 tgt_filepath: str,
                 tokenizer: Tokenizer,
                 min_sen_len=0,
                 source_max_sen_len=None,
                 target_max_sen_len=80,
                 schema_address=None):
        super(BiLingualDataLoader, self).__init__()
        self._src_filepath = src_filepath
        self._tgt_filepath = tgt_filepath
        self.tokenizer = tokenizer
        self.min_sen_len = min_sen_len
        self.source_max_sen_len = source_max_sen_len
        self.target_max_sen_len = target_max_sen_len
        self.schema_address = schema_address

    def _load(self):
        count = 0
        if self.source_max_sen_len is None:
            with open(self._src_filepath, "r") as _src_file:
                print(f" | count the max_sen_len of corpus {self._src_filepath}.")
                max_src = 0
                for _, _pair in enumerate(_src_file):
                    src_tokens = [
                        int(self.tokenizer.tok2idx[t])
                        for t in _pair.strip().split(" ") if t
                    ]
                    src_len = len(src_tokens)
                    if src_len > max_src:
                        max_src = src_len
                self.source_max_sen_len = max_src + 2

        if self.target_max_sen_len is None:
            with open(self._tgt_filepath, "r") as _tgt_file:
                print(f" | count the max_sen_len of corpus {self._tgt_filepath}.")
                max_tgt = 0
                for _, _pair in enumerate(_tgt_file):
                    tgt_tokens = [
                        int(self.tokenizer.tok2idx[t])
                        for t in _pair.strip().split(" ") if t
                    ]
                    tgt_len = len(tgt_tokens)
                    if tgt_len > max_tgt:
                        max_tgt = tgt_len
                self.target_max_sen_len = max_tgt + 1

        with open(self._src_filepath, "r") as _src_file:
            print(f" | Processing corpus {self._src_filepath}.")
            print(f" | Processing corpus {self._tgt_filepath}.")
            with open(self._tgt_filepath, "r") as _tgt_file:
                for _, _pair in enumerate(zip(_src_file, _tgt_file)):

                    src_tokens = [
                        int(self.tokenizer.tok2idx[t])
                        for t in _pair[0].strip().split(" ") if t
                    ]
                    tgt_tokens = [
                        int(self.tokenizer.tok2idx[t])
                        for t in _pair[1].strip().split(" ") if t
                    ]
                    src_tokens.insert(0, self.tokenizer.bos_index)
                    src_tokens.append(self.tokenizer.eos_index)
                    tgt_tokens.insert(0, self.tokenizer.bos_index)
                    tgt_tokens.append(self.tokenizer.eos_index)
                    src_tokens = np.array(src_tokens)
                    tgt_tokens = np.array(tgt_tokens)
                    src_len = src_tokens.shape[0]
                    tgt_len = tgt_tokens.shape[0]

                    if (src_len > self.source_max_sen_len) or (src_len < self.min_sen_len) or (
                            tgt_len > (self.target_max_sen_len + 1)) or (tgt_len < self.min_sen_len):
                        print(f"+++++ delete! src_len={src_len}, tgt_len={tgt_len - 1}, "
                              f"source_max_sen_len={self.source_max_sen_len},"
                              f"target_max_sen_len={self.target_max_sen_len}")
                        continue
                    # encoder inputs
                    encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)
                    src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
                    for i in range(src_len):
                        src_padding[i] = 1
                    src_length = np.array([src_len], dtype=np.int64)
                    # decoder inputs
                    decoder_input = self.padding(tgt_tokens[:-1], self.tokenizer.padding_index, self.target_max_sen_len)
                    # decoder outputs
                    decoder_output = self.padding(tgt_tokens[1:], self.tokenizer.padding_index, self.target_max_sen_len)
                    tgt_padding = np.zeros(shape=self.target_max_sen_len + 1, dtype=np.int64)
                    for j in range(tgt_len):
                        tgt_padding[j] = 1
                    tgt_padding = tgt_padding[1:]
                    decoder_input = np.array(decoder_input, dtype=np.int64)
                    decoder_output = np.array(decoder_output, dtype=np.int64)
                    tgt_padding = np.array(tgt_padding, dtype=np.int64)

                    example = {
                        "src": encoder_input,
                        "src_padding": src_padding,
                        "src_length": src_length,
                        "prev_opt": decoder_input,
                        "target": decoder_output,
                        "tgt_padding": tgt_padding
                    }
                    self._add_example(example)
                    count += 1

        print(f" | source padding_len = {self.source_max_sen_len}.")
        print(f" | target padding_len = {self.target_max_sen_len}.")
        print(f" | Total activate sen = {count}.")
        print(f" | Total sen = {count}.")

        if self.schema_address is not None:
            provlist = [count, self.source_max_sen_len, self.source_max_sen_len, 1,
                        self.target_max_sen_len, self.target_max_sen_len, self.target_max_sen_len]
            columns = ["src", "src_padding", "src_length", "prev_opt", "target", "tgt_padding"]
            with open(self.schema_address, "w", encoding="utf-8") as f:
                f.write("{\n")
                f.write('  "datasetType":"TF",\n')
                f.write('  "numRows":%s,\n' % provlist[0])
                f.write('  "columns":{\n')
                t = 1
                for name in columns:
                    f.write('    "%s":{\n' % name)
                    f.write('      "type":"int64",\n')
                    f.write('      "rank":1,\n')
                    f.write('      "shape":[%s]\n' % provlist[t])
                    f.write('    }')
                    if t < len(columns):
                        f.write(',')
                    f.write('\n')
                    t += 1
                f.write('  }\n}\n')
            print(" | Write to " + self.schema_address)


class TextDataLoader(DataLoader):
    """Loader for text data."""

    def __init__(self,
                 src_filepath: str,
                 tokenizer: Tokenizer,
                 min_sen_len=0,
                 source_max_sen_len=None,
                 schema_address=None):
        super(TextDataLoader, self).__init__()
        self._src_filepath = src_filepath
        self.tokenizer = tokenizer
        self.min_sen_len = min_sen_len
        self.source_max_sen_len = source_max_sen_len
        self.schema_address = schema_address

    def _load(self):
        count = 0
        if self.source_max_sen_len is None:
            with open(self._src_filepath, "r") as _src_file:
                print(f" | count the max_sen_len of corpus {self._src_filepath}.")
                max_src = 0
                for _, _pair in enumerate(_src_file):
                    src_tokens = self.tokenizer.tokenize(_pair)
                    src_len = len(src_tokens)
                    if src_len > max_src:
                        max_src = src_len
                self.source_max_sen_len = max_src

        with open(self._src_filepath, "r") as _src_file:
            print(f" | Processing corpus {self._src_filepath}.")
            for _, _pair in enumerate(_src_file):
                src_tokens = self.tokenizer.tokenize(_pair)
                src_len = len(src_tokens)
                src_tokens = np.array(src_tokens)
                # encoder inputs
                encoder_input = self.padding(src_tokens, self.tokenizer.padding_index, self.source_max_sen_len)
                src_padding = np.zeros(shape=self.source_max_sen_len, dtype=np.int64)
                for i in range(src_len):
                    src_padding[i] = 1
                src_length = np.array([src_len], dtype=np.int64)

                example = {
                    "src": encoder_input,
                    "src_padding": src_padding,
                    "src_length": src_length
                }
                self._add_example(example)
                count += 1

        print(f" | source padding_len = {self.source_max_sen_len}.")
        print(f" | Total activate sen = {count}.")
        print(f" | Total sen = {count}.")

        if self.schema_address is not None:
            provlist = [count, self.source_max_sen_len, self.source_max_sen_len, 1]
            columns = ["src", "src_padding", "src_length"]
            with open(self.schema_address, "w", encoding="utf-8") as f:
                f.write("{\n")
                f.write('  "datasetType":"TF",\n')
                f.write('  "numRows":%s,\n' % provlist[0])
                f.write('  "columns":{\n')
                t = 1
                for name in columns:
                    f.write('    "%s":{\n' % name)
                    f.write('      "type":"int64",\n')
                    f.write('      "rank":1,\n')
                    f.write('      "shape":[%s]\n' % provlist[t])
                    f.write('    }')
                    if t < len(columns):
                        f.write(',')
                    f.write('\n')
                    t += 1
                f.write('  }\n}\n')
            print(" | Write to " + self.schema_address)

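A worked toy illustration of the example layout the bilingual loader above produces. The token ids and lengths are invented; 0, 2, and 3 are the tokenizer's pad, bos, and eos indexes, and the helper pad_to simply mirrors DataLoader.padding.

import numpy as np

PAD, BOS, EOS = 0, 2, 3
src = [BOS, 11, 12, 13, EOS]              # source sentence wrapped in <s> ... <\s>
tgt = [BOS, 21, 22, EOS]                  # target sentence wrapped in <s> ... <\s>
source_max_sen_len, target_max_sen_len = 6, 5

def pad_to(tokens, length):
    out = np.full(length, PAD, dtype=np.int64)
    out[:len(tokens)] = tokens
    return out

encoder_input = pad_to(src, source_max_sen_len)          # [ 2 11 12 13  3  0]
src_padding = np.zeros(source_max_sen_len, dtype=np.int64)
src_padding[:len(src)] = 1                               # [ 1  1  1  1  1  0]
decoder_input = pad_to(tgt[:-1], target_max_sen_len)     # [ 2 21 22  0  0]  (shifted right)
decoder_output = pad_to(tgt[1:], target_max_sen_len)     # [21 22  3  0  0]  (shifted left)
tgt_padding = np.zeros(target_max_sen_len + 1, dtype=np.int64)
tgt_padding[:len(tgt)] = 1
tgt_padding = tgt_padding[1:]                            # [ 1  1  1  0  0]
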
@@ -0,0 +1,147 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset loader to feed into model."""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC


def _load_dataset(input_files, schema_file, batch_size, epoch_count=1,
                  sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True,
                  drop_remainder=True, is_translate=False):
    """
    Load dataset according to passed in params.

    Args:
        input_files (list): Data files.
        schema_file (str): Schema file path.
        batch_size (int): Batch size.
        epoch_count (int): Epoch count.
        sink_mode (bool): Whether enable sink mode.
        sink_step (int): Step to sink.
        rank_size (int): Rank size.
        rank_id (int): Rank id.
        shuffle (bool): Whether shuffle dataset.
        drop_remainder (bool): Whether drop the last possibly incomplete batch.
        is_translate (bool): Whether translate the text.

    Returns:
        Dataset, dataset instance.
    """
    if not input_files:
        raise FileNotFoundError("Require at least one dataset.")

    if not (schema_file and
            os.path.exists(schema_file)
            and os.path.isfile(schema_file)
            and os.path.basename(schema_file).endswith(".json")):
        raise FileNotFoundError("`dataset_schema` must be an existing json file.")

    if not isinstance(sink_mode, bool):
        raise ValueError("`sink_mode` must be of type bool.")

    for datafile in input_files:
        print(f" | Loading {datafile}.")

    if not is_translate:
        ds = de.TFRecordDataset(
            input_files, schema_file,
            columns_list=[
                "src", "src_padding", "src_length",
                "prev_opt",
                "target", "tgt_padding"
            ],
            shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
            shard_equal_rows=True, num_parallel_workers=8)

        ori_dataset_size = ds.get_dataset_size()
        print(f" | Dataset size: {ori_dataset_size}.")

        type_cast_op = deC.TypeCast(mstype.int32)
        ds = ds.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
        ds = ds.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
        ds = ds.map(input_columns="src_length", operations=type_cast_op, num_parallel_workers=8)
        ds = ds.map(input_columns="prev_opt", operations=type_cast_op, num_parallel_workers=8)
        ds = ds.map(input_columns="target", operations=type_cast_op, num_parallel_workers=8)
        ds = ds.map(input_columns="tgt_padding", operations=type_cast_op, num_parallel_workers=8)

        ds = ds.rename(
            input_columns=["src",
                           "src_padding",
                           "src_length",
                           "prev_opt",
                           "target",
                           "tgt_padding"],
            output_columns=["source_eos_ids",
                            "source_eos_mask",
                            "source_eos_length",
                            "target_sos_ids",
                            "target_eos_ids",
                            "target_eos_mask"]
        )
        ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    else:
        ds = de.TFRecordDataset(
            input_files, schema_file,
            columns_list=[
                "src", "src_padding", "src_length"
            ],
            shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
            shard_equal_rows=True, num_parallel_workers=8)

        ori_dataset_size = ds.get_dataset_size()
        print(f" | Dataset size: {ori_dataset_size}.")

        type_cast_op = deC.TypeCast(mstype.int32)
        ds = ds.map(input_columns="src", operations=type_cast_op, num_parallel_workers=8)
        ds = ds.map(input_columns="src_padding", operations=type_cast_op, num_parallel_workers=8)
        ds = ds.map(input_columns="src_length", operations=type_cast_op, num_parallel_workers=8)

        ds = ds.rename(
            input_columns=["src",
                           "src_padding",
                           "src_length"],
            output_columns=["source_eos_ids",
                            "source_eos_mask",
                            "source_eos_length"]
        )
        ds = ds.batch(batch_size, drop_remainder=drop_remainder)

    return ds


def load_dataset(data_files: list, schema: str, batch_size: int, epoch_count: int, sink_mode: bool, sink_step: int = 1,
                 rank_size: int = 1, rank_id: int = 0, shuffle=True, drop_remainder=True, is_translate=False):
    """
    Load dataset.

    Args:
        data_files (list): Data files.
        schema (str): Schema file path.
        batch_size (int): Batch size.
        epoch_count (int): Epoch count.
        sink_mode (bool): Whether enable sink mode.
        sink_step (int): Step to sink.
        rank_size (int): Rank size.
        rank_id (int): Rank id.
        shuffle (bool): Whether shuffle dataset.

    Returns:
        Dataset, dataset instance.
    """
    return _load_dataset(data_files, schema, batch_size, epoch_count, sink_mode,
                         sink_step, rank_size, rank_id, shuffle=shuffle,
                         drop_remainder=drop_remainder, is_translate=is_translate)

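A minimal sketch of iterating one training batch, assuming create_dataset.py has already produced the tfrecord and schema files referenced in config.json above.

from src.dataset import load_dataset

ds = load_dataset(
    data_files=["/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.tfrecord-001-of-001"],
    schema="/home/workspace/dataset_menu/train.tok.clean.bpe.32000.en.json",
    batch_size=128, epoch_count=1, sink_mode=False)

for batch in ds.create_dict_iterator():
    # Columns after rename: source_eos_ids, source_eos_mask, source_eos_length,
    # target_sos_ids, target_eos_ids, target_eos_mask.
    print({name: value.shape for name, value in batch.items()})
    break
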
@@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define schema of mindrecord."""

SCHEMA = {
    "src": {"type": "int64", "shape": [-1]},
    "src_padding": {"type": "int64", "shape": [-1]},
    "src_length": {"type": "int64", "shape": [-1]},
    "prev_opt": {"type": "int64", "shape": [-1]},
    "target": {"type": "int64", "shape": [-1]},
    "tgt_padding": {"type": "int64", "shape": [-1]},
}

@@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tokenizer."""
import os
from collections import defaultdict
from functools import partial
import subword_nmt.apply_bpe
import sacremoses


class Tokenizer:
    """
    Tokenizer class.
    """

    def __init__(self, vocab_address=None, bpe_code_address=None,
                 src_en='en', tgt_de='de', vocab_pad=8, isolator='@@'):
        """
        Constructor for the Tokenizer class.

        Args:
            vocab_address: vocabulary address.
            bpe_code_address: path to the file with bpe codes.
            vocab_pad: pads vocabulary to a multiple of 'vocab_pad' tokens.
            isolator: tokenization isolator.
        """
        self.padding_index = 0
        self.unk_index = 1
        self.bos_index = 2
        self.eos_index = 3
        self.pad_word = '<pad>'
        self.unk_word = '<unk>'
        self.bos_word = '<s>'
        self.eos_word = r'<\s>'
        self.isolator = isolator
        self.init_bpe(bpe_code_address)
        self.vocab_establist(vocab_address, vocab_pad)
        self.sacremoses_tokenizer = sacremoses.MosesTokenizer(src_en)
        self.sacremoses_detokenizer = sacremoses.MosesDetokenizer(tgt_de)

    def init_bpe(self, bpe_code_address):
        """Init bpe."""
        if (bpe_code_address is not None) and os.path.exists(bpe_code_address):
            with open(bpe_code_address, 'r') as f1:
                self.bpe = subword_nmt.apply_bpe.BPE(f1)

    def vocab_establist(self, vocab_address, vocab_pad):
        """Establish vocabulary."""
        if (vocab_address is None) or (not os.path.exists(vocab_address)):
            return
        vocab_words = [self.pad_word, self.unk_word, self.bos_word, self.eos_word]
        with open(vocab_address) as f1:
            for sentence in f1:
                vocab_words.append(sentence.strip())
        vocab_size = len(vocab_words)
        padded_vocab_size = (vocab_size + vocab_pad - 1) // vocab_pad * vocab_pad
        for idx in range(0, padded_vocab_size - vocab_size):
            fil_token = f'filled{idx:04d}'
            vocab_words.append(fil_token)
        self.vocab_size = len(vocab_words)
        self.tok2idx = defaultdict(partial(int, self.unk_index))
        for idx, token in enumerate(vocab_words):
            self.tok2idx[token] = idx
        self.idx2tok = {}
        self.idx2tok = defaultdict(partial(str, ","))
        for token, idx in self.tok2idx.items():
            self.idx2tok[idx] = token

    def tokenize(self, sentence):
        """Tokenize sentence."""
        tokenized = self.sacremoses_tokenizer.tokenize(sentence, return_str=True)
        bpe = self.bpe.process_line(tokenized)
        sentence = bpe.strip().split()
        inputs = [self.tok2idx[i] for i in sentence]
        inputs = [self.bos_index] + inputs + [self.eos_index]
        return inputs

    def detokenize(self, indexes, gap=' '):
        """Detokenizes single sentence and removes token isolator characters."""
        reconstruction_bpe = gap.join([self.idx2tok[idx] for idx in indexes])
        reconstruction_bpe = reconstruction_bpe.replace(self.isolator + ' ', '')
        reconstruction_bpe = reconstruction_bpe.replace(self.isolator, '')
        reconstruction_bpe = reconstruction_bpe.replace(self.bos_word, '')
        reconstruction_bpe = reconstruction_bpe.replace(self.eos_word, '')
        reconstruction_bpe = reconstruction_bpe.replace(self.unk_word, '')
        reconstruction_bpe = reconstruction_bpe.replace(self.pad_word, '')
        reconstruction_bpe = reconstruction_bpe.strip()
        reconstruction_words = self.sacremoses_detokenizer.detokenize(reconstruction_bpe.split())
        return reconstruction_words

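A minimal round-trip sketch, assuming the WMT16 vocab and BPE files used elsewhere in this change (the paths are the defaults from create_dataset.py): tokenize applies Moses tokenization plus BPE and wraps the ids in bos/eos, and detokenize merges the BPE pieces back and strips the special tokens.

from src.dataset.tokenizer import Tokenizer

tokenizer = Tokenizer("/home/workspace/wmt16_de_en/vocab.bpe.32000",
                      "/home/workspace/wmt16_de_en/bpe.32000",
                      src_en='en', tgt_de='de', vocab_pad=8)

ids = tokenizer.tokenize("The quick brown fox.")   # [2, ..., 3] with <s>/<\s> added
text = tokenizer.detokenize(ids)                   # BPE pieces merged, specials removed
print(ids, text)
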
@@ -0,0 +1,34 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""GNMTv2 Init."""
from config.config import GNMTConfig
from .gnmt import GNMT
from .attention import BahdanauAttention
from .gnmt_for_train import GNMTTraining, LabelSmoothedCrossEntropyCriterion, \
    GNMTNetworkWithLoss, GNMTTrainOneStepWithLossScaleCell
from .gnmt_for_infer import infer
from .bleu_calculate import bleu_calculate

__all__ = [
    "infer",
    "GNMTTraining",
    "LabelSmoothedCrossEntropyCriterion",
    "GNMTTrainOneStepWithLossScaleCell",
    "GNMTNetworkWithLoss",
    "GNMT",
    "BahdanauAttention",
    "GNMTConfig",
    "bleu_calculate"
]

File diff suppressed because it is too large
@ -0,0 +1,93 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Calculate the BLEU scores."""
import subprocess

import numpy as np

from src.dataset.tokenizer import Tokenizer


def load_result_data(result_npy_addr):
    # Load the numpy result file into a list.
    result = np.load(result_npy_addr, allow_pickle=True)
    return result


def get_bleu_data(tokenizer: Tokenizer, result_npy_addr):
    """
    Detokenize the predictions.

    Args:
        tokenizer (Tokenizer): Tokenizer operations.
        result_npy_addr (str): Path to the prediction file.

    Returns:
        List, the detokenized prediction sentences.
    """
    result = load_result_data(result_npy_addr)
    prediction_list = []
    for _, info in enumerate(result):
        # Detokenize the prediction.
        prediction = info["prediction"]
        prediction_str = tokenizer.detokenize(prediction)
        prediction_list.append(prediction_str)

    return prediction_list


def calculate_sacrebleu(predict_path, target_path):
    """
    Calculate the BLEU scores with the sacrebleu command-line tool.

    Args:
        predict_path (str): Path to the prediction file.
        target_path (str): Path to the target (reference) file.

    Returns:
        Float, the BLEU score.
    """
    sacrebleu_params = '--score-only -lc --tokenize intl'
    sacrebleu = subprocess.run([f'sacrebleu --input {predict_path} \
                               {target_path} {sacrebleu_params}'],
                               stdout=subprocess.PIPE, shell=True)
    bleu_scores = round(float(sacrebleu.stdout.strip()), 2)
    return bleu_scores


def bleu_calculate(tokenizer, result_npy_addr, target_addr=None):
    """
    Calculate the BLEU scores.

    Args:
        tokenizer (Tokenizer): Tokenizer operations.
        result_npy_addr (str): Path to the prediction file.
        target_addr (str): Path to the target (reference) file.

    Returns:
        Float, the BLEU score.
    """
    prediction = get_bleu_data(tokenizer, result_npy_addr)
    print("predict:\n", prediction)

    eval_path = './predict.txt'
    with open(eval_path, 'w') as eval_file:
        lines = [line + '\n' for line in prediction]
        eval_file.writelines(lines)
    reference_path = target_addr
    bleu_scores = calculate_sacrebleu(eval_path, reference_path)
    return bleu_scores
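A hedged usage sketch of the flow above. All paths and the Tokenizer arguments are placeholders, and sacrebleu must be installed because calculate_sacrebleu shells out to the sacrebleu CLI.

from src.dataset.tokenizer import Tokenizer

# Placeholder arguments and paths -- substitute the real vocabulary, BPE codes,
# inference output (.npy) and reference file.
tokenizer = Tokenizer(vocab_address="vocab.bpe.32000.en", bpe_code_address="bpe.32000")
score = bleu_calculate(tokenizer,
                       result_npy_addr="./output.npy",
                       target_addr="./newstest2014.de")
print("BLEU:", score)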
@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Components of model."""
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore.ops import operations as P


class SaturateCast(nn.Cell):
    """Cast wrapper."""

    def __init__(self, dst_type=mstype.float32):
        super(SaturateCast, self).__init__()
        self.cast = P.Cast()
        self.dst_type = dst_type

    def construct(self, x):
        return self.cast(x, self.dst_type)


class LayerNorm(nn.Cell):
    """
    Do layer norm.

    Args:
        in_channels (int): Number of input channels to normalize.
        return_2d (bool): Whether to return a 2D tensor.

    Returns:
        Tensor, output.
    """

    def __init__(self, in_channels=None, return_2d=False):
        super(LayerNorm, self).__init__()
        self.return_2d = return_2d
        self.layer_norm = nn.LayerNorm((in_channels,))
        self.cast = P.Cast()
        self.get_dtype = P.DType()
        self.reshape = P.Reshape()
        self.get_shape = P.Shape()

    def construct(self, input_tensor):
        """Do layer norm."""
        shape = self.get_shape(input_tensor)
        batch_size = shape[0]
        max_len = shape[1]
        embed_dim = shape[2]

        output = self.reshape(input_tensor, (-1, embed_dim))
        output = self.cast(output, mstype.float32)
        output = self.layer_norm(output)
        output = self.cast(output, self.get_dtype(input_tensor))
        if not self.return_2d:
            output = self.reshape(output, (batch_size, max_len, embed_dim))
        return output
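A minimal sketch of how these cells behave, assuming it is run in the same module where SaturateCast and LayerNorm are defined (so mstype is already imported) and that MindSpore and NumPy are installed:

import numpy as np
from mindspore import Tensor

x = Tensor(np.random.randn(2, 8, 1024).astype(np.float16))   # [batch, seq_len, hidden]
norm = LayerNorm(in_channels=1024)       # normalization runs in float32, result cast back to fp16
y = norm(x)                              # shape (2, 8, 1024), dtype float16
to_fp32 = SaturateCast(dst_type=mstype.float32)
y32 = to_fp32(y)                         # same values, dtype float32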
@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create attention block."""
import mindspore.common.dtype as mstype
from mindspore import nn

from .attention import BahdanauAttention


class RecurrentAttention(nn.Cell):
    """
    Constructor for the RecurrentAttention.

    Args:
        input_size (int): Number of features in the input tensor.
        context_size (int): Number of features in the output from the encoder.
        hidden_size (int): Internal hidden size.
        num_layers (int): Number of layers in the LSTM.
        dropout (float): Probability of dropout (applied to the input of the LSTM layer).
        initializer_range (float): Range for the uniform initializer.

    Returns:
        Tuple of Tensors: rnn outputs, attention outputs, rnn state and attention scores.
    """

    def __init__(self,
                 rnn,
                 is_training=True,
                 input_size=1024,
                 context_size=1024,
                 hidden_size=1024,
                 num_layers=1,
                 dropout=0.2,
                 initializer_range=0.1):
        super(RecurrentAttention, self).__init__()
        self.dropout = nn.Dropout(keep_prob=1.0 - dropout)
        self.rnn = rnn
        self.attn = BahdanauAttention(is_training=is_training,
                                      query_size=hidden_size,
                                      key_size=hidden_size,
                                      num_units=hidden_size,
                                      normalize=True,
                                      initializer_range=initializer_range,
                                      compute_type=mstype.float16)

    def construct(self, decoder_embedding, context_key, attention_mask=None, rnn_init_state=None):
        # decoder_embedding: [t_q, N, D]
        # context_key: [t_k, N, D]
        # attention_mask: [N, t_k]
        # [t_q, N, D]
        decoder_embedding = self.dropout(decoder_embedding)
        rnn_outputs, rnn_state = self.rnn(decoder_embedding, rnn_init_state)
        # rnn_outputs: [t_q, b, D], attn_outputs: [t_q, b, D], scores: [b, t_q, t_k], rnn_state: tuple([2, b, D]).
        attn_outputs, scores = self.attn(query=rnn_outputs, keys=context_key, attention_mask=attention_mask)
        return rnn_outputs, attn_outputs, rnn_state, scores
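For readers unfamiliar with the additive (Bahdanau) scoring that BahdanauAttention wraps, the following is a self-contained NumPy illustration of the same idea. It is an illustrative re-implementation with random weights, not the repo's code, and it uses a [batch, time, dim] layout for brevity rather than the [time, batch, dim] layout noted above.

import numpy as np

def additive_attention(query, keys, hidden=8, seed=0):
    """Toy additive attention: query [b, t_q, d], keys [b, t_k, d]."""
    rng = np.random.default_rng(seed)
    d = query.shape[-1]
    w_q = rng.normal(scale=0.1, size=(d, hidden))
    w_k = rng.normal(scale=0.1, size=(d, hidden))
    v = rng.normal(scale=0.1, size=(hidden,))
    q = query @ w_q                                            # [b, t_q, hidden]
    k = keys @ w_k                                             # [b, t_k, hidden]
    # score[b, i, j] = v . tanh(q_i W_q + k_j W_k)
    scores = np.tanh(q[:, :, None, :] + k[:, None, :, :]) @ v  # [b, t_q, t_k]
    weights = np.exp(scores - scores.max(-1, keepdims=True))
    weights /= weights.sum(-1, keepdims=True)                  # softmax over the key axis
    context = weights @ keys                                   # [b, t_q, d]
    return context, weights

b, t_q, t_k, d = 2, 3, 5, 16
context, weights = additive_attention(np.random.randn(b, t_q, d),
                                      np.random.randn(b, t_k, d))
print(context.shape, weights.shape)   # (2, 3, 16) (2, 3, 5)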
Some files were not shown because too many files have changed in this diff.