commit 7a64fb1948
@@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Apply bpe script."""
import os
import argparse
from multiprocessing import Pool, cpu_count

from src.utils import Dictionary
from src.utils import bpe_encode

parser = argparse.ArgumentParser(description='Apply BPE.')
parser.add_argument("--codes", type=str, default="", required=True,
                    help="bpe codes path.")
parser.add_argument("--src_folder", type=str, default="", required=True,
                    help="raw corpus folder.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="encoded corpus output path.")
parser.add_argument("--prefix", type=str, default="", required=False,
                    help="Prefix of text file.")
parser.add_argument("--vocab_path", type=str, default="", required=True,
                    help="Generated vocabulary output path.")
parser.add_argument("--threshold", type=int, default=None, required=False,
                    help="Filter out words that frequency is lower than threshold.")
parser.add_argument("--processes", type=int, default=2, required=False,
                    help="Number of processes to use.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    if not (args.codes and args.src_folder and args.output_folder):
        raise ValueError("Please enter required params.")

    source_folder = args.src_folder
    output_folder = args.output_folder
    codes = args.codes

    if not os.path.exists(codes):
        raise FileNotFoundError("`--codes` is not existed.")
    if not os.path.exists(source_folder) or not os.path.isdir(source_folder):
        raise ValueError("`--src_folder` must be a dir and existed.")
    if not os.path.exists(output_folder) or not os.path.isdir(output_folder):
        raise ValueError("`--output_folder` must be a dir and existed.")
    if not isinstance(args.prefix, str) or len(args.prefix) > 128:
        raise ValueError("`--prefix` must be a str and len <= 128.")
    if not isinstance(args.processes, int):
        raise TypeError("`--processes` must be an integer.")

    available_dict = []
    args_groups = []
    for file in os.listdir(source_folder):
        if args.prefix and not file.startswith(args.prefix):
            continue
        if file.endswith(".txt"):
            output_path = os.path.join(output_folder, file.replace(".txt", "_bpe.txt"))
            dict_path = os.path.join(output_folder, file.replace(".txt", ".dict"))
            available_dict.append(dict_path)
            args_groups.append((codes, os.path.join(source_folder, file),
                                output_path, dict_path))

    kernel_size = 1 if args.processes <= 0 else args.processes
    kernel_size = min(kernel_size, cpu_count())
    pool = Pool(kernel_size)
    for arg in args_groups:
        pool.apply_async(bpe_encode, args=arg)
    pool.close()
    pool.join()

    vocab = Dictionary.load_from_text(available_dict)
    if args.threshold is not None:
        vocab = vocab.shrink(args.threshold)
    vocab.persistence(args.vocab_path)
    print(f" | Vocabulary Size: {len(vocab)}")
@@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MASS model configuration."""
from .config import TransformerConfig

__all__ = [
    "TransformerConfig"
]
@@ -0,0 +1,243 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration class for Transformer."""
import os
import json
import copy
from typing import List

import mindspore.common.dtype as mstype


def _is_dataset_file(file: str):
    return "tfrecord" in file.lower() or "mindrecord" in file.lower()


def _get_files_from_dir(folder: str):
    _files = []
    for file in os.listdir(folder):
        if _is_dataset_file(file):
            _files.append(os.path.join(folder, file))
    return _files


def get_source_list(folder: str) -> List:
    """
    Get file list from a folder.

    Returns:
        list, file list.
    """
    _list = []
    if not folder:
        return _list

    if os.path.isdir(folder):
        _list = _get_files_from_dir(folder)
    else:
        if _is_dataset_file(folder):
            _list.append(folder)
    return _list


PARAM_NODES = {"dataset_config",
               "model_config",
               "loss_scale_config",
               "learn_rate_config",
               "checkpoint_options"}


class TransformerConfig:
    """
    Configuration for `Transformer`.

    Args:
        random_seed (int): Random seed.
        batch_size (int): Batch size of input dataset.
        epochs (int): Epoch number.
        dataset_sink_mode (bool): Whether enable dataset sink mode.
        dataset_sink_step (int): Dataset sink step.
        lr_scheduler (str): Whether use lr_scheduler, only support "ISR" now.
        lr (float): Initial learning rate.
        min_lr (float): Minimum learning rate.
        decay_start_step (int): Step to decay.
        warmup_steps (int): Warm up steps.
        dataset_schema (str): Path of dataset schema file.
        pre_train_dataset (str): Path of pre-training dataset file or folder.
        fine_tune_dataset (str): Path of fine-tune dataset file or folder.
        test_dataset (str): Path of test dataset file or folder.
        valid_dataset (str): Path of validation dataset file or folder.
        ckpt_path (str): Checkpoints save path.
        save_ckpt_steps (int): Interval of saving ckpt.
        ckpt_prefix (str): Prefix of ckpt file.
        keep_ckpt_max (int): Max ckpt files number.
        seq_length (int): Length of input sequence. Default: 64.
        vocab_size (int): The shape of each embedding vector. Default: 46192.
        hidden_size (int): Size of embedding, attention, dim. Default: 512.
        num_hidden_layers (int): Encoder, Decoder layers.
        ngram (int): Number of tokens to predict ahead. Default: 2.
        accumulation_steps (int): Number of steps to hold until next gradient optimization. Default: 1.
        num_attention_heads (int): Number of hidden layers in the Transformer encoder/decoder
            cell. Default: 6.
        intermediate_size (int): Size of intermediate layer in the Transformer
            encoder/decoder cell. Default: 4096.
        hidden_act (str): Activation function used in the Transformer encoder/decoder
            cell. Default: "relu".
        loss_scale_mode (str): Loss scale mode. Default: "dynamic".
        init_loss_scale (int): Initialized loss scale.
        loss_scale_factor (int): Loss scale factor.
        scale_window (int): Window size of loss scale.
        beam_width (int): Beam width for beam search in inferring. Default: 4.
        length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
        label_smoothing (float): Label smoothing setting. Default: 0.1.
        input_mask_from_dataset (bool): Specifies whether to use the input mask that loaded from
            dataset. Default: True.
        save_graphs (bool): Whether to save graphs, please set to True if mindinsight
            is wanted.
        dtype (mstype): Data type of the input. Default: mstype.float32.
        max_decode_length (int): Max decode length for inferring. Default: 64.
        hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
        attention_dropout_prob (float): The dropout probability for
            Multi-head Self-Attention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this
            model. Default: 512.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """

    def __init__(self,
                 random_seed=74,
                 batch_size=64, epochs=1,
                 dataset_sink_mode=True, dataset_sink_step=1,
                 lr_scheduler="", optimizer="adam",
                 lr=1e-4, min_lr=1e-6,
                 decay_steps=10000, poly_lr_scheduler_power=1,
                 decay_start_step=-1, warmup_steps=2000,
                 pre_train_dataset: str = None,
                 fine_tune_dataset: str = None,
                 test_dataset: str = None,
                 valid_dataset: str = None,
                 ckpt_path: str = None,
                 save_ckpt_steps=2000,
                 ckpt_prefix="CKPT",
                 existed_ckpt="",
                 keep_ckpt_max=20,
                 seq_length=128,
                 vocab_size=46192,
                 hidden_size=512,
                 num_hidden_layers=6,
                 ngram=2,
                 accumulation_steps=1,
                 disable_ngram_loss=False,
                 num_attention_heads=8,
                 intermediate_size=4096,
                 hidden_act="relu",
                 hidden_dropout_prob=0.1,
                 attention_dropout_prob=0.1,
                 max_position_embeddings=64,
                 initializer_range=0.02,
                 loss_scale_mode="dynamic",
                 init_loss_scale=2 ** 10,
                 loss_scale_factor=2, scale_window=2000,
                 beam_width=5,
                 length_penalty_weight=1.0,
                 label_smoothing=0.1,
                 input_mask_from_dataset=True,
                 save_graphs=False,
                 dtype=mstype.float32,
                 max_decode_length=64):

        self.save_graphs = save_graphs
        self.random_seed = random_seed
        self.pre_train_dataset = get_source_list(pre_train_dataset)  # type: List[str]
        self.fine_tune_dataset = get_source_list(fine_tune_dataset)  # type: List[str]
        self.valid_dataset = get_source_list(valid_dataset)  # type: List[str]
        self.test_dataset = get_source_list(test_dataset)  # type: List[str]

        if not isinstance(epochs, int) or epochs < 0:
            raise ValueError("`epochs` must be a non-negative integer.")

        self.epochs = epochs
        self.dataset_sink_mode = dataset_sink_mode
        self.dataset_sink_step = dataset_sink_step

        self.ckpt_path = ckpt_path
        self.keep_ckpt_max = keep_ckpt_max
        self.save_ckpt_steps = save_ckpt_steps
        self.ckpt_prefix = ckpt_prefix
        self.existed_ckpt = existed_ckpt

        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.ngram = ngram
        self.accumulation_steps = accumulation_steps
        self.disable_ngram_loss = disable_ngram_loss
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.label_smoothing = label_smoothing

        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        self.max_decode_length = max_decode_length
        self.input_mask_from_dataset = input_mask_from_dataset
        self.compute_type = mstype.float32
        self.dtype = dtype

        self.loss_scale_mode = loss_scale_mode
        self.scale_window = scale_window
        self.loss_scale_factor = loss_scale_factor
        self.init_loss_scale = init_loss_scale

        self.optimizer = optimizer
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.min_lr = min_lr
        self.poly_lr_scheduler_power = poly_lr_scheduler_power
        self.decay_steps = decay_steps
        self.decay_start_step = decay_start_step
        self.warmup_steps = warmup_steps

        self.train_url = ""

    @classmethod
    def from_dict(cls, json_object: dict):
        """Constructs a `TransformerConfig` from a Python dictionary of parameters."""
        _params = {}
        for node in PARAM_NODES:
            for key in json_object[node]:
                _params[key] = json_object[node][key]
        return cls(**_params)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `TransformerConfig` from a json file of parameters."""
        with open(json_file, "r") as reader:
            return cls.from_dict(json.load(reader))

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
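A minimal usage sketch (not part of the diff; the path is a placeholder): from_json_file() reads a config such as the JSON files below, from_dict() flattens the five PARAM_NODES sections into a single keyword dictionary, and each key becomes a constructor argument of TransformerConfig.

    # hypothetical example, mirroring how eval.py imports and uses the class
    from config import TransformerConfig

    cfg = TransformerConfig.from_json_file("config/config.json")  # placeholder path
    print(cfg.batch_size, cfg.hidden_size, cfg.lr_scheduler)      # flattened attributes
    print(cfg.to_json_string())                                   # serialize attributes back to JSON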
@@ -0,0 +1,59 @@
{
    "dataset_config": {
        "epochs": 5,
        "batch_size": 1,
        "pre_train_dataset": "",
        "fine_tune_dataset": "../cnndm_data_prophetnet/dataset_hugging_face_tokenized/train",
        "test_dataset": "",
        "valid_dataset": "",
        "dataset_sink_mode": false,
        "dataset_sink_step": 100
    },
    "model_config": {
        "random_seed": 1,
        "save_graphs": false,
        "seq_length": 512,
        "vocab_size": 30522,
        "hidden_size": 512,
        "num_hidden_layers": 3,
        "ngram": 2,
        "accumulation_steps": 1,
        "disable_ngram_loss": false,
        "num_attention_heads": 8,
        "intermediate_size": 2048,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "initializer_range": 0.02,
        "label_smoothing": 0.1,
        "beam_width": 5,
        "length_penalty_weight": 1.0,
        "max_decode_length": 64,
        "input_mask_from_dataset": true
    },
    "loss_scale_config": {
        "loss_scale_mode": "static",
        "init_loss_scale": 1,
        "loss_scale_factor": 2,
        "scale_window": 200
    },
    "learn_rate_config": {
        "optimizer": "adam",
        "lr": 1e-4,
        "lr_scheduler": "isr",
        "poly_lr_scheduler_power": 0.5,
        "decay_steps": 10000,
        "decay_start_step": 1000,
        "warmup_steps": 1000,
        "min_lr": 1e-7
    },
    "checkpoint_options": {
        "existed_ckpt": "",
        "save_ckpt_steps": 20000,
        "keep_ckpt_max": 50,
        "ckpt_prefix": "ckpt",
        "ckpt_path": "checkpoints"
    }
}
@@ -0,0 +1,58 @@
{
    "dataset_config": {
        "epochs": 2,
        "batch_size": 1,
        "pre_train_dataset": "../news_crawl/dataset/tf_small_pretrain",
        "fine_tune_dataset": "",
        "test_dataset": "",
        "valid_dataset": "",
        "dataset_sink_mode": false,
        "dataset_sink_step": 100
    },
    "model_config": {
        "random_seed": 100,
        "save_graphs": false,
        "seq_length": 128,
        "vocab_size": 44000,
        "hidden_size": 768,
        "num_hidden_layers": 3,
        "ngram": 2,
        "disable_ngram_loss": false,
        "num_attention_heads": 12,
        "intermediate_size": 3072,
        "hidden_act": "relu",
        "hidden_dropout_prob": 0.1,
        "attention_dropout_prob": 0.1,
        "max_position_embeddings": 64,
        "initializer_range": 0.02,
        "label_smoothing": 0.1,
        "beam_width": 4,
        "length_penalty_weight": 1.0,
        "max_decode_length": 64,
        "input_mask_from_dataset": true
    },
    "loss_scale_config": {
        "loss_scale_mode": "static",
        "init_loss_scale": 32,
        "loss_scale_factor": 2,
        "scale_window": 200
    },
    "learn_rate_config": {
        "optimizer": "adam",
        "lr": 1e-4,
        "lr_scheduler": "poly",
        "poly_lr_scheduler_power": 0.5,
        "decay_steps": 10000,
        "decay_start_step": 12000,
        "warmup_steps": 4000,
        "min_lr": 1e-6
    },
    "checkpoint_options": {
        "existed_ckpt": "/home/yanglinfeng/ProphetNet/training_result/checkpoints/ckpt_1_0.ckpt",
        "save_ckpt_steps": 10,
        "keep_ckpt_max": 50,
        "ckpt_prefix": "ckpt",
        "ckpt_path": "checkpoints"
    }
}
@@ -0,0 +1,57 @@
{
    "dataset_config": {
        "epochs": 2,
        "batch_size": 1,
        "pre_train_dataset": "",
        "fine_tune_dataset": "",
        "test_dataset": "../cnndm_data_prophetnet/dataset_hugging_face_tokenized",
        "valid_dataset": "",
        "dataset_sink_mode": false,
        "dataset_sink_step": 100
    },
    "model_config": {
        "random_seed": 100,
        "save_graphs": false,
        "seq_length": 512,
        "vocab_size": 30522,
        "hidden_size": 512,
        "num_hidden_layers": 3,
        "ngram": 2,
        "disable_ngram_loss": false,
        "num_attention_heads": 8,
        "intermediate_size": 2048,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "initializer_range": 0.02,
        "label_smoothing": 0.1,
        "beam_width": 5,
        "length_penalty_weight": 1.2,
        "max_decode_length": 110,
        "input_mask_from_dataset": true
    },
    "loss_scale_config": {
        "loss_scale_mode": "static",
        "init_loss_scale": 32,
        "loss_scale_factor": 2,
        "scale_window": 200
    },
    "learn_rate_config": {
        "optimizer": "adam",
        "lr": 1e-4,
        "lr_scheduler": "poly",
        "poly_lr_scheduler_power": 0.5,
        "decay_steps": 10000,
        "decay_start_step": 12000,
        "warmup_steps": 4000,
        "min_lr": 1e-6
    },
    "checkpoint_options": {
        "existed_ckpt": "../training_weight/ckpt-1_20000.ckpt",
        "save_ckpt_steps": 500,
        "keep_ckpt_max": 50,
        "ckpt_prefix": "ckpt",
        "ckpt_path": "checkpoints"
    }
}
@@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation api."""
import os
import argparse
import pickle

from mindspore.common import dtype as mstype
from mindspore import context

from config import TransformerConfig
from src.transformer import infer, infer_ppl
from src.utils import Dictionary
from src.utils import get_score

parser = argparse.ArgumentParser(description='Evaluation MASS.')
parser.add_argument("--config", type=str, required=True,
                    help="Model config json file path.")
parser.add_argument("--vocab", type=str, required=True,
                    help="Vocabulary to use.")
parser.add_argument("--output", type=str, required=True,
                    help="Result file path.")
parser.add_argument("--metric", type=str, default='rouge',
                    help='Set eval method.')
parser.add_argument("--platform", type=str, required=True,
                    help="model working platform.")


def get_config(config):
    config = TransformerConfig.from_json_file(config)
    config.compute_type = mstype.float32
    config.dtype = mstype.float32
    return config


if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    if args.vocab.endswith("bin"):
        vocab = Dictionary.load_from_persisted_dict(args.vocab)
    else:
        vocab = Dictionary.load_from_text([args.vocab])
    _config = get_config(args.config)

    device_id = os.getenv('DEVICE_ID', None)
    if device_id is None:
        device_id = 0
    device_id = int(device_id)
    context.set_context(
        # mode=context.GRAPH_MODE,
        mode=context.PYNATIVE_MODE,
        device_target=args.platform,
        reserve_class_name_in_scope=False,
        device_id=device_id)

    if args.metric == 'rouge':
        result = infer(_config)
    else:
        result = infer_ppl(_config)

    with open(args.output, "wb") as f:
        pickle.dump(result, f, 1)

    # get score by given metric
    score = get_score(result, vocab, metric=args.metric)
    print(score)
@@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate Gigaword dataset."""
import os
import argparse

from src.dataset import BiLingualDataLoader
from src.language_model import NoiseChannelLanguageModel
from src.utils import Dictionary

parser = argparse.ArgumentParser(description='Create Gigaword fine-tune Dataset.')
parser.add_argument("--train_src", type=str, default="", required=False,
                    help="train dataset source file path.")
parser.add_argument("--train_ref", type=str, default="", required=False,
                    help="train dataset reference file path.")
parser.add_argument("--test_src", type=str, default="", required=False,
                    help="test dataset source file path.")
parser.add_argument("--test_ref", type=str, default="", required=False,
                    help="test dataset reference file path.")
parser.add_argument("--noise_prob", type=float, default=0., required=False,
                    help="add noise prob.")
parser.add_argument("--existed_vocab", type=str, default="", required=False,
                    help="existed vocab path.")
parser.add_argument("--max_len", type=int, default=64, required=False,
                    help="max length of sentences.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="dataset output path.")
parser.add_argument("--format", type=str, default="tfrecord", required=False,
                    help="dataset format.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)

    if args.train_src and args.train_ref:
        train = BiLingualDataLoader(
            src_filepath=args.train_src,
            tgt_filepath=args.train_ref,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=args.noise_prob),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            train.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.tfrecord")
            )
        else:
            train.write_to_mindrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.mindrecord")
            )

    if args.test_src and args.test_ref:
        test = BiLingualDataLoader(
            src_filepath=args.test_src,
            tgt_filepath=args.test_ref,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=0),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            test.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_test_dataset.tfrecord")
            )
        else:
            test.write_to_mindrecord(
                path=os.path.join(args.output_folder, "gigaword_test_dataset.mindrecord")
            )

    print(f" | Vocabulary size: {vocab.size}.")
@@ -0,0 +1,209 @@
python tokenize_corpus.py --corpus_folder /{path}/corpus --output_folder /{path}/tokenized_corpus --tokenizer nltk --pool_size 16
cd tokenized_corpus/

# build bpe codes
cat *.txt | subword-nmt learn-bpe -s 46000 -o all.bpe.codes

# build bpe dict
subword-nmt get-vocab -i tokenized.txt -o vocab_en.dict.bin

# apply bpe encoding
python apply_bpe_encoding.py --codes ~/Mindspore/mindspore/model_zoo/official/nlp/mass/tokenized_corpus/all.bpe.codes \
    --src_folder ~/Mindspore/mindspore/model_zoo/official/nlp/mass/tokenized_corpus/ \
    --output_folder ~/Mindspore/mindspore/model_zoo/official/nlp/mass/tokenized_corpus/bpe \
    --vocab_path ~/Mindspore/mindspore/model_zoo/official/nlp/mass/tokenized_corpus/vocab_en.dict.bin \
    --processes 32

# build dataset news crawl
python news_crawl.py --src_folder ./news_crawl \
    --dict_folder ./news_crawl \
    --existed_vocab ./tokenized_corpus/vocab_en.dict.bin \
    --mask_ratio 0.5 \
    --output_folder ./news_crawl/dataset/tf_small_pretrain \
    --max_len 128 \
    --processes 32 \
    --ngram 2

# build dataset cnndm
python cnn_dm.py --test_src ./cnndm_data_prophetnet/prophetnet_tokenized/test.src.txt --test_ref ./cnndm_data_prophetnet/prophetnet_tokenized/test.tgt.txt --existed_vocab ./cnndm_data_prophetnet/cnndm_torch_prophetnet_30522.bin --noise_prob 0.0 --output_folder ./cnndm_data_prophetnet/dataset_hugging_face_tokenized/ --max_len 512

# train
bash run_gpu.sh --task t --device_num 1 --device_id 3 --config ./config/config.json

# inference
bash run_gpu.sh --task i \
    --device_num 1 \
    --device_id 3 \
    --config ./config/test.json \
    --output output \
    --metric rouge \
    --vocab ./cnndm_data_prophetnet/cnndm_torch_prophetnet_30522.bin

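# For reference, the "apply bpe encoding" step can also be driven directly from Python with the
# subword-nmt package used above; this is a hedged sketch (file names are placeholders), mirroring
# what apply_bpe_encoding.py delegates to bpe_encode:
#
#     import codecs
#     from subword_nmt.apply_bpe import BPE   # same package as the `subword-nmt` CLI
#
#     # load the learned merge operations, then encode one tokenized file line by line
#     with codecs.open("all.bpe.codes", encoding="utf-8") as codes_file:
#         bpe = BPE(codes_file)
#
#     with codecs.open("train.txt", encoding="utf-8") as fin, \
#          codecs.open("train_bpe.txt", "w", encoding="utf-8") as fout:
#         for line in fin:
#             fout.write(bpe.process_line(line))
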
# pytorch model structure
NgramTransformerProphetModel(
  (encoder): TransformerEncoder(
    (embed_tokens): Embedding(30522, 512, padding_idx=0)
    (embed_positions): LearnedPositionalEmbedding(513, 512, padding_idx=0)
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=True)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=True)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (2): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=True)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
    (emb_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): NgramTransformerDecoder(
    (embed_tokens): Embedding(30522, 512, padding_idx=0)
    (embed_positions): LearnedPositionalEmbedding(514, 512, padding_idx=0)
    (ngram_input_embed): Embedding(2, 512)
    (layers): ModuleList(
      (0): NgramTransformerDecoderLayer(
        (ngram_self_attn): NgramMultiheadAttention(
          (relative_linear): Linear(in_features=512, out_features=256, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): MultiheadAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=True)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (1): NgramTransformerDecoderLayer(
        (ngram_self_attn): NgramMultiheadAttention(
          (relative_linear): Linear(in_features=512, out_features=256, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): MultiheadAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=True)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (2): NgramTransformerDecoderLayer(
        (ngram_self_attn): NgramMultiheadAttention(
          (relative_linear): Linear(in_features=512, out_features=256, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (encoder_attn): MultiheadAttention(
          (k_proj): Linear(in_features=512, out_features=512, bias=True)
          (v_proj): Linear(in_features=512, out_features=512, bias=True)
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (out_proj): Linear(in_features=512, out_features=512, bias=True)
        )
        (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
    )
    (emb_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
)

data example:
src_tokens
tensor([[ 1996, 11555, 18172,  7042,  2055,  1037, 18147,  5913,  3756,  6982,
          1999,  1996,  4120,  1012,  2007,  1996,  4022,  2000,  2022,  3621,
          2062,  4795,  1010,  2021,  2074,  2004, 26102,  1010,  1996,  7726,
          3212,  2038,  2042, 27696,  1996,  6745,  2804,  2000,  2049,  4170,
          1011,  1037,  8235,  4408, 28653,  2630,  6982,  1012, 11216,  1997,
          1996, 27143,  1011,  2550, 21905,  2442,  2031,  2245,  2008,  1996,
         13576,  8703,  2052,  2191,  1996,  7477, 12586,  1999,  2007,  1996,
          2784,  5380,  1997,  1996,  2152, 11915,  1012, 17186,  2091,  2005,
          2678,  1012,  3239,  1011,  9105,  1024,  7726,  3212,  9058,  2020,
          4760,  2125,  2037,  4408, 28653, 12622,  2006,  2110,  2547,  1012,
         18783,  1024,  7726,  3212,  3738,  3233,  2006,  2327,  1997,  1996,
          8254,  2050,  1021,  6982,  2328, 27143,  1012,  2021,  2009,  1005,
          1055,  2524,  2000,  2903,  2008,  1996,  4099,  2180,  1005,  1056,
          2156,  2023,  2028,  2746,  2007,  1996,  6120,  2437,  2009,  3233,
          2041,  2066,  1037, 14699,  7639,  2114,  1996,  2300,  1005,  1055,
          3302,  1012,  1996,  3212,  2001,  4760,  2125,  1996,  3239,  1011,
          9105,  4325,  1010,  2029,  2003,  2105,  1996,  2946,  1997,  1037,
         15437,  1010,  2006,  4238,  2110,  2547,  7483,  1012,  3212,  4584,
          1010,  2738,  4603,  2135,  5102,  1999,  5810,  2601, 11408,  4102,
          2000,  2037, 28190,  2911,  1010,  3427,  2004,  1996,  8254,  2050,
          1011,  1021,  1010,  6055,  2007,  3424,  1011,  2911, 10815,  1010,
          2001,  3390,  2012, 24112,  2099, 17532,  1010,  2379,  1996,  6143,
         11195,  1997,  7570, 10867, 17040,  1012,  2048,  2047,  7726,  1011,
          2328,  1043, 16102,  4313,  4942,  2015,  1998,  2048, 13671, 25215,
         11890, 27528,  2102,  2020,  2036,  5359,  2000,  1996,  3212,  1012,
          8235,  2630,  1024,  4238,  1005,  1055,  4397,  3390,  1043, 16102,
          4313,  6982,  5829,  1999,  2392,  1997,  1037,  4049,  1999,  1996,
          2670,  3417,  1997, 24112,  2099, 17532,  1999,  1996,  4723,  6084,
          1012, 19194,  1024,  1996, 12622,  3233,  2041,  2066,  1037, 14699,
          1011,  7639,  2114,  1996,  3302,  1997,  1996,  2712,  1012,  3212,
          2708,  4373,  5902,  5292, 28065, 14511,  4430,  2360, 13380,  2072,
          2001,  9339,  2006,  7726,  2547,  2004,  3038,  2008,  1996,  3842,
          2442, 10295,  1996,  1005, 14751,  2974,  1998,  2327,  1011,  3694,
          4128,  2000,  4047,  2049,  6645,  1012,  1005,  1043, 16102,  4313,
          2465, 12622,  2064,  2543, 10815,  1998, 18544,  2012,  1996,  2168,
          2051,  1010,  1998,  2064,  5452,  1999,  1996,  4723,  6084,  1005,
          1055,  8467,  5380,  1012,  4238,  2038,  4912,  2000, 12200,  2049,
          2250,  3639,  1998,  3987,  9859,  1010,  3038,  2151,  2825,  2925,
          4491,  2006,  2009,  2052,  2272,  2013,  1996,  2250,  1998,  2712,
          1012,  1996,  2406,  2085,  4447,  2000,  2022,  1005,  2969,  7182,
          1005,  1999,  3408,  1997, 17731,  3941,  2000,  3113,  2049,  2510,
          3791,  1012, 14430,  1024,  1996,  7726,  6982,  1005,  1055,  2453,
          2022,  2062,  9252,  2084,  1996, 11555,  1005, 21864, 15952,  3756,
          6982,  1010, 15885,  1010,  2021,  2027,  2024,  8053, 14224, 11401,
          1012,   102]], device='cuda:0')
prev_output_tokens
tensor([[  102,  7726,  2110,  2547,  3662,  8333,  1997,  1996,  2047,  3719,
          1011,  1037,  8254,  2050,  1021,  6982,  1010,  2048,  1043, 16102,
          4313,  4942,  2015,  1998,  1037,  3940,  1997, 25215, 11890, 27528,
          2102,  1012,     2,  3212,  4584,  2360,  2008,  1996,  4170,  2442,
         10295,  1005,  1996, 14751,  2974,  1005,  2000,  4047,  2049,  6645,
          1012]], device='cuda:0')
target_tokens:
tensor([[ 7726,  2110,  2547,  3662,  8333,  1997,  1996,  2047,  3719,  1011,
          1037,  8254,  2050,  1021,  6982,  1010,  2048,  1043, 16102,  4313,
          4942,  2015,  1998,  1037,  3940,  1997, 25215, 11890, 27528,  2102,
          1012,     2,  3212,  4584,  2360,  2008,  1996,  4170,  2442, 10295,
          1005,  1996, 14751,  2974,  1005,  2000,  4047,  2049,  6645,  1012,
           102]], device='cuda:0')
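# As the tensors above show, prev_output_tokens is the teacher-forcing decoder input:
# it is target_tokens shifted right by one position, with the separator id 102 moved to the front.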
@@ -0,0 +1,61 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate News Crawl corpus dataset."""
import argparse

from src.utils import Dictionary
from src.utils.preprocess import create_pre_training_dataset

parser = argparse.ArgumentParser(description='Create News Crawl Pre-Training Dataset.')
parser.add_argument("--src_folder", type=str, default="", required=True,
                    help="Raw corpus folder.")
parser.add_argument("--existed_vocab", type=str, default="", required=True,
                    help="Existed vocab path.")
parser.add_argument("--mask_ratio", type=float, default=0.4, required=True,
                    help="Mask ratio.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="Dataset output path.")
parser.add_argument("--max_len", type=int, default=32, required=False,
                    help="Max length of sentences.")
parser.add_argument("--ngram", type=int, default=3, required=True,
                    help="Number of tokens to predict ahead.")
parser.add_argument("--suffix", type=str, default="", required=False,
                    help="Add suffix to output file.")
parser.add_argument("--processes", type=int, default=2, required=False,
                    help="Size of processes pool.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    if not (args.src_folder and args.output_folder):
        raise ValueError("Please enter required params.")

    if not args.existed_vocab:
        raise ValueError("`--existed_vocab` is required.")

    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)

    create_pre_training_dataset(
        folder_path=args.src_folder,
        output_folder_path=args.output_folder,
        vocabulary=vocab,
        prefix="news.20", suffix=args.suffix,
        mask_ratio=args.mask_ratio,
        ngram=args.ngram,
        min_sen_len=10,
        max_sen_len=args.max_len,
        dataset_type="tfrecord",
        cores=args.processes
    )
    print(f" | Vocabulary size: {vocab.size}.")
@@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

src_folder_path=$1  # source text folder path.

cd $src_folder_path || exit
cat *.txt | subword-nmt learn-bpe -s 46000 -o all.bpe.codes
@@ -0,0 +1,179 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

options=`getopt -u -o ht:n:i:j:c:o:v:m: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab:,metric: -- "$@"`
eval set -- "$options"
echo $options

echo_help()
{
    echo "Usage:"
    echo "bash train.sh [-h] [-t t|i] [-n N] [-i N] [-j FILE] [-c FILE] [-o FILE] [-v FILE]"
    echo "options:"
    echo "        -h --help                show usage"
    echo "        -t --task                select task, 't' for training and 'i' for inference"
    echo "        -n --device_num          training with N devices"
    echo "        -i --device_id           training with device i"
    echo "        -j --hccl_json           set the rank table file"
    echo "        -c --config              set the configuration file"
    echo "        -o --output              set the output file of inference"
    echo "        -v --vocab               set the vocabulary"
    echo "        -m --metric              set the metric"
}

set_hccl_json()
{
    while [ -n "$1" ]
    do
        if [[ "$1" == "-j" || "$1" == "--hccl_json" ]]
        then
            export RANK_TABLE_FILE=$2
            break
        fi
        shift
    done
}

set_device_id()
{
    while [ -n "$1" ]
    do
        if [[ "$1" == "-i" || "$1" == "--device_id" ]]
        then
            if [[ $2 -ge 0 && $2 -le 7 ]]
            then
                export DEVICE_ID=$2
            fi
            break
        fi
        shift
    done
}

while [ -n "$1" ]
do
    case "$1" in
    -h|--help)
        echo_help
        shift
        ;;
    -t|--task)
        echo "task:"
        if [ "$2" == "t" ]
        then
            task=train
        elif [ "$2" == "i" ]
        then
            task=infer
        fi
        shift 2
        ;;
    -n|--device_num)
        echo "device_num"
        if [ $2 -eq 1 ]
        then
            set_device_id $options
        elif [ $2 -gt 1 ]
        then
            export HCCL_FLAG=1
            export DEPLOY_MODE=0

            export RANK_SIZE=$2
            set_hccl_json $options
        fi
        shift 2
        ;;
    -i|--device_id)
        echo "set device id"
        export DEVICE_ID=$2
        shift 2
        ;;
    -c|--config)
        echo "config";
        configurations=$2
        shift 2
        ;;
    -o|--output)
        echo "output";
        output=$2
        shift 2
        ;;
    -v|--vocab)
        echo "vocab";
        vocab=$2
        shift 2
        ;;
    -m|--metric)
        echo "metric";
        metric=$2
        shift 2
        ;;
    --)
        shift
        break
        ;;
    *)
        shift
        ;;
    esac
done

file_path=$(cd "$(dirname $0)" || exit; pwd)
for((i=0; i < $RANK_SIZE; i++))
do
    if [ $RANK_SIZE -gt 1 ]
    then
        echo $RANK_SIZE
        export RANK_ID=$i
        export DEVICE_ID=$[i]
    fi
    echo "Working on device $i"

    cd $file_path || exit
    cd ../ || exit

    rm -rf ./${task}_prophetnet_$DEVICE_ID
    mkdir ./${task}_prophetnet_$DEVICE_ID

    cp train_gradient_accumulation.py ./${task}_prophetnet_$DEVICE_ID
    cp train.py ./${task}_prophetnet_$DEVICE_ID
    cp eval.py ./${task}_prophetnet_$DEVICE_ID
    cp -r src ./${task}_prophetnet_$DEVICE_ID
    cp -r config ./${task}_prophetnet_$DEVICE_ID
    cp $configurations ./${task}_prophetnet_$DEVICE_ID

    if [ $vocab ]
    then
        cp $vocab ./${task}_prophetnet_$DEVICE_ID
    fi

    cd ./${task}_prophetnet_$DEVICE_ID || exit
    env > log.log
    echo $task
    if [ "$task" == "train" ]
    then
        #python train.py --config ${configurations##*/} --platform Ascend >>log.log 2>&1 &
        python train.py --config ${configurations##*/} --platform Ascend
    elif [ "$task" == "infer" ]
    then
        #python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} --platform Ascend >>log_infer.log 2>&1 &
        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} --platform Ascend
    fi
    cd ../
done
@@ -0,0 +1,162 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

options=`getopt -u -o ht:n:i::o:v:m: -l help,task:,device_num:,device_id:,config:,output:,vocab:,metric: -- "$@"`
eval set -- "$options"
echo $options

echo_help()
{
    echo "Usage:"
    echo "bash train.sh [-h] [-t t|i] [-n N] [-i N] [-j FILE] [-c FILE] [-o FILE] [-v FILE]"
    echo "options:"
    echo "        -h --help                show usage"
    echo "        -t --task                select task, 't' for training and 'i' for inference"
    echo "        -n --device_num          training with N devices"
    echo "        -i --device_id           training with device i"
    echo "        -c --config              set the configuration file"
    echo "        -o --output              set the output file of inference"
    echo "        -v --vocab               set the vocabulary"
    echo "        -m --metric              set the metric"
}

set_device_id()
{
    while [ -n "$1" ]
    do
        if [[ "$1" == "-i" || "$1" == "--device_id" ]]
        then
            if [[ $2 -ge 0 && $2 -le 7 ]]
            then
                export DEVICE_ID=$2
            fi
            break
        fi
        shift
    done
}

while [ -n "$1" ]
do
    case "$1" in
    -h|--help)
        echo_help
        shift
        ;;
    -t|--task)
        echo "task:"
        if [ "$2" == "t" ]
        then
            task=train
        elif [ "$2" == "i" ]
        then
            task=infer
        fi
        shift 2
        ;;
    -n|--device_num)
        echo "device_num"
        if [ $2 -eq 1 ]
        then
            set_device_id $options
        elif [ $2 -gt 1 ]
        then
            export RANK_SIZE=$2
        fi
        shift 2
        ;;
    -i|--device_id)
        echo "set device id"
        export DEVICE_ID=$2
        shift 2
        ;;
    -c|--config)
        echo "config";
        configurations=$2
        shift 2
        ;;
    -o|--output)
        echo "output";
        output=$2
        shift 2
        ;;
    -v|--vocab)
        echo "vocab";
        vocab=$2
        shift 2
        ;;
    -m|--metric)
        echo "metric";
        metric=$2
        shift 2
        ;;
    --)
        shift
        break
        ;;
    *)
        shift
        ;;
    esac
done

file_path=$(cd "$(dirname $0)" || exit; pwd)
if [ $RANK_SIZE -gt 1 ]
then
    echo "Working on $RANK_SIZE device"
fi
echo "Working on file ${task}_prophetnet_$DEVICE_ID"

cd $file_path || exit
cd ../ || exit

rm -rf ./${task}_prophetnet_$DEVICE_ID
mkdir ./${task}_prophetnet_$DEVICE_ID

cp train_gradient_accumulation.py ./${task}_prophetnet_$DEVICE_ID
cp train.py ./${task}_prophetnet_$DEVICE_ID
cp eval.py ./${task}_prophetnet_$DEVICE_ID
cp -r src ./${task}_prophetnet_$DEVICE_ID
cp -r config ./${task}_prophetnet_$DEVICE_ID
cp $configurations ./${task}_prophetnet_$DEVICE_ID

if [ $vocab ]
then
    cp $vocab ./${task}_prophetnet_$DEVICE_ID
fi

cd ./${task}_prophetnet_$DEVICE_ID || exit
env > log.log
echo $task
if [ "$task" == "train" ]
then
    if [ $RANK_SIZE -gt 1 ]
    then
        mpirun -n $RANK_SIZE python train.py --config ${configurations##*/} --platform GPU >>log.log 2>&1 &
    fi
    #python train.py --config ${configurations##*/} --platform GPU >>log.log 2>&1 &
    python train.py --config ${configurations##*/} --platform GPU
elif [ "$task" == "infer" ]
then
    #python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} --platform GPU >>log_infer.log 2>&1 &
    python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} --metric ${metric} --platform GPU
fi
cd ../
"""Source of mass model."""
|
||||||
|
from .dataset import load_dataset
|
||||||
|
from .dataset import bi_data_loader
|
||||||
|
from .dataset import mono_data_loader
|
||||||
|
from .transformer import TransformerDecoder
|
||||||
|
from .transformer import TransformerEncoder
|
||||||
|
from .transformer import Transformer
|
||||||
|
from .transformer import TransformerNetworkWithLoss
|
||||||
|
from .transformer import LabelSmoothedCrossEntropyCriterion
|
||||||
|
from .transformer import TransformerTrainOneStepWithLossScaleCell
|
||||||
|
from .transformer import TransformerTraining
|
||||||
|
from .transformer import infer
|
||||||
|
from .language_model import LooseMaskedLanguageModel
|
||||||
|
from .language_model import MaskedLanguageModel
|
||||||
|
from .language_model import NoiseChannelLanguageModel
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"load_dataset",
|
||||||
|
"bi_data_loader",
|
||||||
|
"mono_data_loader",
|
||||||
|
"Transformer",
|
||||||
|
"infer",
|
||||||
|
"TransformerTraining",
|
||||||
|
"TransformerNetworkWithLoss",
|
||||||
|
"TransformerTrainOneStepWithLossScaleCell",
|
||||||
|
"LabelSmoothedCrossEntropyCriterion",
|
||||||
|
"LooseMaskedLanguageModel",
|
||||||
|
"MaskedLanguageModel",
|
||||||
|
"NoiseChannelLanguageModel"
|
||||||
|
]
|
@ -0,0 +1,24 @@
"""Dataset module."""
|
||||||
|
from .bi_data_loader import BiLingualDataLoader
|
||||||
|
from .mono_data_loader import MonoLingualDataLoader
|
||||||
|
from .load_dataset import load_dataset
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"load_dataset",
|
||||||
|
"BiLingualDataLoader",
|
||||||
|
"MonoLingualDataLoader"
|
||||||
|
]
|
@ -0,0 +1,111 @@
"""Dataset loader to feed into model."""
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset.engine as de
|
||||||
|
import mindspore.dataset.transforms.c_transforms as deC
|
||||||
|
|
||||||
|
|
||||||
|
def _load_dataset(input_files, batch_size, epoch_count=1,
|
||||||
|
sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True):
|
||||||
|
"""
|
||||||
|
Load dataset according to passed in params.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_files (list): Data files.
|
||||||
|
batch_size (int): Batch size.
|
||||||
|
epoch_count (int): Epoch count.
|
||||||
|
sink_mode (bool): Whether enable sink mode.
|
||||||
|
sink_step (int): Step to sink.
|
||||||
|
rank_size (int): Rank size.
|
||||||
|
rank_id (int): Rank id.
|
||||||
|
shuffle (bool): Whether shuffle dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset, dataset instance.
|
||||||
|
"""
|
||||||
|
if not input_files:
|
||||||
|
raise FileNotFoundError("Require at least one dataset.")
|
||||||
|
|
||||||
|
if not isinstance(sink_mode, bool):
|
||||||
|
raise ValueError("`sink` must be type of bool.")
|
||||||
|
|
||||||
|
for datafile in input_files:
|
||||||
|
print(f" | Loading {datafile}.")
|
||||||
|
|
||||||
|
ds = de.TFRecordDataset(
|
||||||
|
input_files,
|
||||||
|
columns_list=[
|
||||||
|
"src", "src_padding",
|
||||||
|
"prev_opt", "prev_padding",
|
||||||
|
"target", "tgt_padding"
|
||||||
|
],
|
||||||
|
shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
|
||||||
|
shard_equal_rows=True, num_parallel_workers=8)
|
||||||
|
|
||||||
|
ori_dataset_size = ds.get_dataset_size()
|
||||||
|
print(f" | Dataset size: {ori_dataset_size}.")
|
||||||
|
repeat_count = epoch_count
|
||||||
|
|
||||||
|
type_cast_op = deC.TypeCast(mstype.int32)
|
||||||
|
ds = ds.map(input_columns="src", operations=type_cast_op)
|
||||||
|
ds = ds.map(input_columns="src_padding", operations=type_cast_op)
|
||||||
|
ds = ds.map(input_columns="prev_opt", operations=type_cast_op)
|
||||||
|
ds = ds.map(input_columns="prev_padding", operations=type_cast_op)
|
||||||
|
ds = ds.map(input_columns="target", operations=type_cast_op)
|
||||||
|
ds = ds.map(input_columns="tgt_padding", operations=type_cast_op)
|
||||||
|
|
||||||
|
ds = ds.rename(
|
||||||
|
input_columns=["src",
|
||||||
|
"src_padding",
|
||||||
|
"prev_opt",
|
||||||
|
"prev_padding",
|
||||||
|
"target",
|
||||||
|
"tgt_padding"],
|
||||||
|
output_columns=["source_eos_ids",
|
||||||
|
"source_eos_mask",
|
||||||
|
"target_sos_ids",
|
||||||
|
"target_sos_mask",
|
||||||
|
"target_eos_ids",
|
||||||
|
"target_eos_mask"]
|
||||||
|
)
|
||||||
|
|
||||||
|
ds = ds.batch(batch_size, drop_remainder=True)
|
||||||
|
ds = ds.repeat(repeat_count)
|
||||||
|
|
||||||
|
ds.channel_name = 'transformer'
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset(data_files: list, batch_size: int, epoch_count: int,
|
||||||
|
sink_mode: bool, sink_step: int = 1, rank_size: int = 1, rank_id: int = 0, shuffle=True):
|
||||||
|
"""
|
||||||
|
Load dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_files (list): Data files.
|
||||||
|
batch_size (int): Batch size.
|
||||||
|
epoch_count (int): Epoch count.
|
||||||
|
sink_mode (bool): Whether enable sink mode.
|
||||||
|
sink_step (int): Step to sink.
|
||||||
|
rank_size (int): Rank size.
|
||||||
|
rank_id (int): Rank id.
|
||||||
|
shuffle (bool): Whether shuffle dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dataset, dataset instance.
|
||||||
|
"""
|
||||||
|
return _load_dataset(data_files, batch_size, epoch_count, sink_mode,
|
||||||
|
sink_step, rank_size, rank_id, shuffle=shuffle)
|
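A minimal usage sketch, not part of this diff: the file name and hyper-parameter values are placeholders chosen for illustration; only the keyword names come from the `load_dataset` signature above.

from src.dataset import load_dataset

ds = load_dataset(data_files=["train_0.tfrecord"],   # placeholder path
                  batch_size=32, epoch_count=1, sink_mode=False,
                  rank_size=1, rank_id=0, shuffle=True)
for batch in ds.create_dict_iterator():
    # columns were renamed above, e.g. "src" -> "source_eos_ids"
    print(batch["source_eos_ids"].shape, batch["target_eos_ids"].shape)
    break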
@ -0,0 +1,24 @@
"""Define schema of mindrecord."""
|
||||||
|
|
||||||
|
SCHEMA = {
|
||||||
|
"src": {"type": "int64", "shape": [-1]},
|
||||||
|
"src_padding": {"type": "int64", "shape": [-1]},
|
||||||
|
"prev_opt": {"type": "int64", "shape": [-1]},
|
||||||
|
"prev_padding": {"type": "int64", "shape": [-1]},
|
||||||
|
"target": {"type": "int64", "shape": [-1]},
|
||||||
|
"tgt_padding": {"type": "int64", "shape": [-1]},
|
||||||
|
}
|
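A hedged sketch of how this SCHEMA could be handed to MindSpore's MindRecord writer; the output file name, the description string, and the example arrays are placeholders, and only the column names come from SCHEMA above.

import numpy as np
from mindspore.mindrecord import FileWriter

writer = FileWriter(file_name="demo.mindrecord", shard_num=1)   # placeholder file
writer.add_schema(SCHEMA, "prophetnet pretraining schema")
sample = {
    "src": np.array([5, 6, 7, 2], dtype=np.int64),
    "src_padding": np.array([1, 1, 1, 1], dtype=np.int64),
    "prev_opt": np.array([4, 5, 6], dtype=np.int64),
    "prev_padding": np.array([1, 1, 1], dtype=np.int64),
    "target": np.array([5, 6, 7], dtype=np.int64),
    "tgt_padding": np.array([1, 1, 1], dtype=np.int64),
}
writer.write_raw_data([sample])   # write one record with the declared columns
writer.commit()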
@ -0,0 +1,29 @@
"""Language model."""
|
||||||
|
from .noise_channel_language_model import NoiseChannelLanguageModel
|
||||||
|
from .masked_language_model import MaskedLanguageModel
|
||||||
|
from .loose_masked_language_model import LooseMaskedLanguageModel
|
||||||
|
from .mass_language_model import MassLanguageModel
|
||||||
|
from .prophetnet_language_model import ProphetNetLanguageModel, NgramNoiseChannelLanguageModel
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LooseMaskedLanguageModel",
|
||||||
|
"MassLanguageModel",
|
||||||
|
"MaskedLanguageModel",
|
||||||
|
"NoiseChannelLanguageModel",
|
||||||
|
"ProphetNetLanguageModel",
|
||||||
|
"NgramNoiseChannelLanguageModel"
|
||||||
|
]
|
@ -0,0 +1,25 @@
"""Base language model."""
|
||||||
|
|
||||||
|
|
||||||
|
class LanguageModel:
|
||||||
|
"""Define base language model."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def emit(self, **kwargs):
|
||||||
|
raise NotImplementedError
|
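A minimal, hypothetical subclass illustrating the contract that `LanguageModel` defines: concrete models override `emit` and return a dict of numpy arrays. `IdentityLanguageModel` is not part of the original code base, and the import path assumes the class sits next to base.py in the language_model package.

import numpy as np

from src.language_model.base import LanguageModel


class IdentityLanguageModel(LanguageModel):
    """Toy model that applies no masking at all."""

    def emit(self, sentence: np.ndarray, vocabulary=None):
        # Return the sentence unchanged as encoder input, decoder input and decoder target.
        return {
            "sentence_length": sentence.shape[0],
            "tgt_sen_length": sentence.shape[0],
            "encoder_input": sentence.copy(),
            "decoder_input": sentence.copy(),
            "decoder_output": sentence.copy()
        }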
@ -0,0 +1,129 @@
"""Modified masked language model."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.utils import Dictionary
|
||||||
|
from .base import LanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class LooseMaskedLanguageModel(LanguageModel):
|
||||||
|
"""
|
||||||
|
Modified mask operation on sentence.
|
||||||
|
|
||||||
|
If k is assigned, then mask sentence with length k.
|
||||||
|
Otherwise, use mask_ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
k (int): Length of fragment.
|
||||||
|
mask_ratio (float): Mask ratio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, k: int = None, mask_ratio=0.5,
|
||||||
|
mask_all_prob=None):
|
||||||
|
super(LooseMaskedLanguageModel, self).__init__()
|
||||||
|
self.mask_ratio = mask_ratio
|
||||||
|
self._k = k
|
||||||
|
self._threshold = mask_all_prob
|
||||||
|
|
||||||
|
def emit(self, sentence: np.ndarray, vocabulary: Dictionary):
|
||||||
|
"""
|
||||||
|
Mask mono source sentence.
|
||||||
|
|
||||||
|
A sample used to train model is processed with following step:
|
||||||
|
|
||||||
|
encoder input (source): [x1, x2, x3, x4, x5, x6, x7, x8, </eos>]
|
||||||
|
masked encoder input: [x1, x2, x3, _, _, _, x7, x8, </eos>]
|
||||||
|
decoder input: [ -, x3, x4, x5]
|
||||||
|
| | | |
|
||||||
|
V V V V
|
||||||
|
decoder output: [x3, x4, x5, x6]
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
A simple rule is made that source sentence starts without <BOS>
|
||||||
|
but end with <EOS>.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocabulary (Dictionary): Vocabulary.
|
||||||
|
sentence (np.ndarray): Raw sentence instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict, an example.
|
||||||
|
"""
|
||||||
|
# If v=0, then u must equal to 0. [u, v)
|
||||||
|
u, v = self._get_masked_interval(sentence.shape[0],
|
||||||
|
self._k, self._threshold)
|
||||||
|
|
||||||
|
encoder_input = sentence.copy()
|
||||||
|
right_shifted_sentence = np.concatenate(([vocabulary.bos_index], sentence[:-1]))
|
||||||
|
if u == 0:
|
||||||
|
_len = v - u if v - u != 0 else sentence.shape[0]
|
||||||
|
decoder_input = right_shifted_sentence[:_len]
|
||||||
|
decoder_input[0] = vocabulary.mask_index
|
||||||
|
decoder_output = sentence[:_len].copy()
|
||||||
|
else:
|
||||||
|
decoder_input = right_shifted_sentence[u - 1:v]
|
||||||
|
decoder_input[0] = vocabulary.mask_index
|
||||||
|
decoder_output = sentence[u - 1:v].copy()
|
||||||
|
|
||||||
|
if v == 0:
|
||||||
|
decoder_input[:] = vocabulary.mask_index
|
||||||
|
else:
|
||||||
|
encoder_input[np.arange(start=u, stop=v)] = vocabulary.mask_index
|
||||||
|
|
||||||
|
if u != v and u > 1:
|
||||||
|
padding = np.array([vocabulary.padding_index] * (u - 1), dtype=np.int32)
|
||||||
|
decoder_input = np.concatenate((padding, decoder_input))
|
||||||
|
decoder_output = np.concatenate((padding, decoder_output))
|
||||||
|
|
||||||
|
if decoder_input.shape[0] != decoder_output.shape[0]:
|
||||||
|
raise ValueError("seq len must equal.")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sentence_length": sentence.shape[0],
|
||||||
|
"tgt_sen_length": decoder_output.shape[0],
|
||||||
|
"encoder_input": encoder_input, # end with </eos>
|
||||||
|
"decoder_input": decoder_input,
|
||||||
|
"decoder_output": decoder_output # end with </eos>
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_masked_interval(self, length, fix_length=None,
|
||||||
|
threshold_to_mask_all=None):
|
||||||
|
"""
|
||||||
|
Generate a sequence length according to length and mask_ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length (int): Sequence length.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[int, int], [start position, end position].
|
||||||
|
"""
|
||||||
|
# Can not larger than sequence length.
|
||||||
|
# Mask_length belongs to [0, length].
|
||||||
|
if fix_length is not None:
|
||||||
|
interval_length = min(length, fix_length)
|
||||||
|
else:
|
||||||
|
interval_length = min(length, round(self.mask_ratio * length))
|
||||||
|
|
||||||
|
_magic = np.random.random()
|
||||||
|
if threshold_to_mask_all is not None and _magic <= threshold_to_mask_all:
|
||||||
|
return 0, length
|
||||||
|
|
||||||
|
# If not sequence to be masked, then return 0, 0.
|
||||||
|
if interval_length == 0:
|
||||||
|
return 0, 0
|
||||||
|
# Otherwise, return start position and interval length.
|
||||||
|
start_pos = np.random.randint(low=0, high=length - interval_length + 1)
|
||||||
|
return start_pos, start_pos + interval_length
|
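A small usage sketch, not taken from this diff: the `Vocab` stand-in below mimics only the three attributes of `Dictionary` that `emit` touches, and the token ids are arbitrary; real code would pass an actual `src.utils.Dictionary`.

import numpy as np

from src.language_model import LooseMaskedLanguageModel


class Vocab:                      # minimal stand-in, not the real Dictionary
    bos_index = 0
    mask_index = 1
    padding_index = 2


np.random.seed(0)                 # make the sampled mask interval repeatable
lm = LooseMaskedLanguageModel(mask_ratio=0.5)
sentence = np.array([10, 11, 12, 13, 14, 15, 3], dtype=np.int32)  # ends with a made-up </eos>=3
example = lm.emit(sentence, Vocab())
print(example["encoder_input"])   # source with the sampled fragment replaced by mask_index
print(example["decoder_input"])   # right-shifted fragment, first token forced to mask_index
print(example["decoder_output"])  # the tokens the decoder must reconstruct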
@ -0,0 +1,128 @@
"""Masked language model."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import LanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedLanguageModel(LanguageModel):
|
||||||
|
"""
|
||||||
|
Do mask operation on sentence.
|
||||||
|
|
||||||
|
If k is assigned, then mask sentence with length k.
|
||||||
|
Otherwise, use mask_ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
k (int): Length of fragment.
|
||||||
|
mask_ratio (float): Mask ratio.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, k: int = None, mask_ratio=0.5,
|
||||||
|
mask_all_prob=None):
|
||||||
|
super(MaskedLanguageModel, self).__init__()
|
||||||
|
self.mask_ratio = mask_ratio
|
||||||
|
self._k = k
|
||||||
|
self._threshold = mask_all_prob
|
||||||
|
|
||||||
|
def emit(self, sentence: np.ndarray, vocabulary):
|
||||||
|
"""
|
||||||
|
Mask mono source sentence.
|
||||||
|
|
||||||
|
A sample used to train model is processed with following step:
|
||||||
|
|
||||||
|
encoder input (source): [x1, x2, x3, x4, x5, x6, x7, x8, </eos>]
|
||||||
|
masked encoder input: [x1, x2, _, _, _, x6, x7, x8, </eos>]
|
||||||
|
decoder input: [ _, x3, x4]
|
||||||
|
| | |
|
||||||
|
V V V
|
||||||
|
decoder output: [ x3, x4, x5]
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
A simple rule is made that source sentence starts without <BOS>
|
||||||
|
but end with <EOS>.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocabulary (Dictionary): Vocabulary.
|
||||||
|
sentence (np.ndarray): Raw sentence instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict, an example.
|
||||||
|
"""
|
||||||
|
encoder_input = sentence.copy()
|
||||||
|
seq_len = encoder_input.shape[0]
|
||||||
|
|
||||||
|
# If v=0, then u must equal to 0. [u, v)
|
||||||
|
u, v = self._get_masked_interval(len(encoder_input),
|
||||||
|
self._k, self._threshold)
|
||||||
|
|
||||||
|
if u == 0:
|
||||||
|
_len = v - u if v - u != 0 else seq_len
|
||||||
|
decoder_input = np.array([vocabulary.mask_index] * _len, dtype=np.int32)
|
||||||
|
decoder_input[1:] = encoder_input[:_len - 1].copy()
|
||||||
|
else:
|
||||||
|
decoder_input = np.array([vocabulary.mask_index] * (v - u), dtype=np.int32)
|
||||||
|
decoder_input[1:] = encoder_input[u:v - 1].copy()
|
||||||
|
|
||||||
|
if v == 0:
|
||||||
|
decoder_output = encoder_input.copy()
|
||||||
|
encoder_input[:] = vocabulary.mask_index
|
||||||
|
else:
|
||||||
|
decoder_output = encoder_input[u:v].copy()
|
||||||
|
encoder_input[np.arange(start=u, stop=v)] = vocabulary.mask_index
|
||||||
|
|
||||||
|
if u != v and u > 0:
|
||||||
|
padding = np.array([vocabulary.padding_index] * u, dtype=np.int32)
|
||||||
|
decoder_input = np.concatenate((padding, decoder_input))
|
||||||
|
decoder_output = np.concatenate((padding, decoder_output))
|
||||||
|
|
||||||
|
assert decoder_input.shape[0] == decoder_output.shape[0], "seq len must equal."
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sentence_length": seq_len,
|
||||||
|
"tgt_sen_length": decoder_output.shape[0],
|
||||||
|
"encoder_input": encoder_input, # end with </eos>
|
||||||
|
"decoder_input": decoder_input,
|
||||||
|
"decoder_output": decoder_output # end with </eos>
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_masked_interval(self, length, fix_length=None,
|
||||||
|
threshold_to_mask_all=None):
|
||||||
|
"""
|
||||||
|
Generate a sequence length according to length and mask_ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length (int): Sequence length.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[int, int], [start position, end position].
|
||||||
|
"""
|
||||||
|
# Can not larger than sequence length.
|
||||||
|
# Mask_length belongs to [0, length].
|
||||||
|
if fix_length is not None:
|
||||||
|
interval_length = min(length, fix_length)
|
||||||
|
else:
|
||||||
|
interval_length = min(length, round(self.mask_ratio * length))
|
||||||
|
|
||||||
|
_magic = np.random.random()
|
||||||
|
if threshold_to_mask_all is not None and _magic <= threshold_to_mask_all:
|
||||||
|
return 0, length
|
||||||
|
|
||||||
|
# If not sequence to be masked, then return 0, 0.
|
||||||
|
if interval_length == 0:
|
||||||
|
return 0, 0
|
||||||
|
# Otherwise, return start position and interval length.
|
||||||
|
start_pos = np.random.randint(low=0, high=length - interval_length + 1)
|
||||||
|
return start_pos, start_pos + interval_length
|
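For comparison with the loose variant above, a brief sketch (illustrative values only) of how a fixed `k` overrides `mask_ratio` when the interval is drawn; it pokes at the private helper purely to make the behaviour visible.

import numpy as np

from src.language_model import MaskedLanguageModel

np.random.seed(0)
ratio_lm = MaskedLanguageModel(mask_ratio=0.5)   # k is None, so the ratio decides
fixed_lm = MaskedLanguageModel(k=3)              # k wins over mask_ratio
length = 12
print(ratio_lm._get_masked_interval(length, ratio_lm._k))   # interval of round(0.5 * 12) = 6 tokens
print(fixed_lm._get_masked_interval(length, fixed_lm._k))   # interval of exactly 3 tokens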
@ -0,0 +1,72 @@
"""Noise channel language model."""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from .base import LanguageModel
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseChannelLanguageModel(LanguageModel):
|
||||||
|
"""Do mask on bilingual data."""
|
||||||
|
|
||||||
|
def __init__(self, add_noise_prob: float = 0.1):
|
||||||
|
super(NoiseChannelLanguageModel, self).__init__()
|
||||||
|
self._noisy_prob = add_noise_prob
|
||||||
|
|
||||||
|
def emit(self, sentence: np.ndarray, target: np.ndarray,
|
||||||
|
mask_symbol_idx: int,
|
||||||
|
bos_symbol_idx: int):
|
||||||
|
"""
|
||||||
|
Add noise to sentence randomly.
|
||||||
|
|
||||||
|
For example, given a sentence pair:
|
||||||
|
source sentence: [x1, x2, x3, x4, x5, x6, </eos>]
|
||||||
|
target sentence: [y1, y2, y3, y4, </eos>]
|
||||||
|
|
||||||
|
After do random mask, data is looked like:
|
||||||
|
encoder input (source): [x1, x2, _, x4, x5, _, </eos>]
|
||||||
|
decoder input: [<bos>, y1, y2, y3, y4]
|
||||||
|
| | | | |
|
||||||
|
V V V V V
|
||||||
|
decoder output: [ y1, y2, y3, y4, </eos>]
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sentence (np.ndarray): Raw sentence.
|
||||||
|
target (np.ndarray): Target output (prediction).
|
||||||
|
mask_symbol_idx (int): Index of MASK symbol.
|
||||||
|
bos_symbol_idx (int): Index of bos symbol.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict, an example.
|
||||||
|
"""
|
||||||
|
encoder_input = sentence.copy()
|
||||||
|
tgt_seq_len = target.shape[0]
|
||||||
|
if self._noisy_prob > 0:
|
||||||
|
for i, _ in enumerate(encoder_input):
|
||||||
|
_prob = np.random.random()
|
||||||
|
if _prob < self._noisy_prob:
|
||||||
|
encoder_input[i] = mask_symbol_idx
|
||||||
|
|
||||||
|
decoder_input = np.empty(shape=tgt_seq_len, dtype=np.int64)
|
||||||
|
decoder_input[1:] = target[:-1]
|
||||||
|
decoder_input[0] = bos_symbol_idx
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sentence_length": encoder_input.shape[0],
|
||||||
|
"tgt_sen_length": tgt_seq_len,
|
||||||
|
"encoder_input": encoder_input, # end with </eos>
|
||||||
|
"decoder_input": decoder_input, # start with <bos>
|
||||||
|
"decoder_output": target # end with </eos>
|
||||||
|
}
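Finally, a hedged end-to-end sketch of the bilingual noise model above; the token ids and symbol indices are made up for illustration and are not taken from this diff.

import numpy as np

from src.language_model import NoiseChannelLanguageModel

np.random.seed(0)
noise_lm = NoiseChannelLanguageModel(add_noise_prob=0.2)
source = np.array([11, 12, 13, 14, 3], dtype=np.int64)   # ends with a made-up </eos>=3
target = np.array([21, 22, 23, 3], dtype=np.int64)
example = noise_lm.emit(source, target, mask_symbol_idx=1, bos_symbol_idx=0)
print(example["encoder_input"])    # source with roughly 20% of tokens replaced by the mask id
print(example["decoder_input"])    # [<bos>, y1, y2, y3]
print(example["decoder_output"])   # [y1, y2, y3, </eos>] -- identical to `target`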
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue