!1628 Implementation of masked seq2seq pre-training for language generation.
Merge pull request !1628 from 刘崇鸣/model_zoo_mass
pull/1628/MERGE
commit
4d95e3340c
@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Apply BPE script."""
import os
import argparse
from multiprocessing import Pool, cpu_count

from src.utils import Dictionary
from src.utils import bpe_encode

parser = argparse.ArgumentParser(description='Apply BPE.')
parser.add_argument("--codes", type=str, default="", required=True,
                    help="BPE codes path.")
parser.add_argument("--src_folder", type=str, default="", required=True,
                    help="Raw corpus folder.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="Encoded corpus output path.")
parser.add_argument("--prefix", type=str, default="", required=False,
                    help="Prefix of text file.")
parser.add_argument("--vocab_path", type=str, default="", required=True,
                    help="Generated vocabulary output path.")
parser.add_argument("--threshold", type=int, default=None, required=False,
                    help="Filter out words whose frequency is lower than the threshold.")
parser.add_argument("--processes", type=int, default=2, required=False,
                    help="Number of processes to use.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    if not (args.codes and args.src_folder and args.output_folder):
        raise ValueError("Please enter required params.")

    source_folder = args.src_folder
    output_folder = args.output_folder
    codes = args.codes

    if not os.path.exists(codes):
        raise FileNotFoundError("`--codes` does not exist.")
    if not os.path.exists(source_folder) or not os.path.isdir(source_folder):
        raise ValueError("`--src_folder` must be an existing directory.")
    if not os.path.exists(output_folder) or not os.path.isdir(output_folder):
        raise ValueError("`--output_folder` must be an existing directory.")
    if not isinstance(args.prefix, str) or len(args.prefix) > 128:
        raise ValueError("`--prefix` must be a str with len <= 128.")
    if not isinstance(args.processes, int):
        raise TypeError("`--processes` must be an integer.")

    available_dict = []
    args_groups = []
    for file in os.listdir(source_folder):
        if args.prefix and not file.startswith(args.prefix):
            continue
        if file.endswith(".txt"):
            output_path = os.path.join(output_folder, file.replace(".txt", "_bpe.txt"))
            dict_path = os.path.join(output_folder, file.replace(".txt", ".dict"))
            available_dict.append(dict_path)
            args_groups.append((codes, os.path.join(source_folder, file),
                                output_path, dict_path))

    kernel_size = 1 if args.processes <= 0 else args.processes
    kernel_size = min(kernel_size, cpu_count())
    pool = Pool(kernel_size)
    for arg in args_groups:
        pool.apply_async(bpe_encode, args=arg)
    pool.close()
    pool.join()

    vocab = Dictionary.load_from_text(available_dict)
    if args.threshold is not None:
        vocab = vocab.shrink(args.threshold)
    vocab.persistence(args.vocab_path)
    print(f" | Vocabulary Size: {len(vocab)}")
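The script above fans each corpus file out to a process pool and then merges the per-file dictionaries into one vocabulary. Below is a minimal, self-contained sketch of that same fan-out/merge pattern; the toy count_tokens worker and the in-memory "corpora" are illustrative stand-ins, not the real bpe_encode or Dictionary from src.utils.

from collections import Counter
from multiprocessing import Pool, cpu_count

def count_tokens(text):
    """Toy worker standing in for bpe_encode: returns a per-file token count."""
    return Counter(text.split())

if __name__ == "__main__":
    corpora = ["new york is big", "the cat sat on the mat", "big data big model"]
    with Pool(min(len(corpora), cpu_count())) as pool:
        async_results = [pool.apply_async(count_tokens, args=(text,)) for text in corpora]
        per_file_counts = [r.get() for r in async_results]   # wait for every worker
    vocab = Counter()
    for counts in per_file_counts:
        vocab.update(counts)   # merge per-file results, analogous to Dictionary.load_from_text
    print(f" | Vocabulary Size: {len(vocab)}")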
@ -0,0 +1,20 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""MASS model configuration."""
from .config import TransformerConfig

__all__ = [
    "TransformerConfig"
]
@ -0,0 +1,54 @@
{
  "dataset_config": {
    "epochs": 20,
    "batch_size": 192,
    "pre_train_dataset": "",
    "fine_tune_dataset": "",
    "test_dataset": "",
    "valid_dataset": "",
    "dataset_sink_mode": false,
    "dataset_sink_step": 100
  },
  "model_config": {
    "random_seed": 100,
    "save_graphs": false,
    "seq_length": 64,
    "vocab_size": 45744,
    "hidden_size": 1024,
    "num_hidden_layers": 6,
    "num_attention_heads": 8,
    "intermediate_size": 4096,
    "hidden_act": "relu",
    "hidden_dropout_prob": 0.2,
    "attention_dropout_prob": 0.2,
    "max_position_embeddings": 64,
    "initializer_range": 0.02,
    "label_smoothing": 0.1,
    "beam_width": 4,
    "length_penalty_weight": 1.0,
    "max_decode_length": 64,
    "input_mask_from_dataset": true
  },
  "loss_scale_config": {
    "init_loss_scale": 65536,
    "loss_scale_factor": 2,
    "scale_window": 200
  },
  "learn_rate_config": {
    "optimizer": "adam",
    "lr": 1e-4,
    "lr_scheduler": "poly",
    "poly_lr_scheduler_power": 0.5,
    "decay_steps": 10000,
    "decay_start_step": 12000,
    "warmup_steps": 4000,
    "min_lr": 1e-6
  },
  "checkpoint_options": {
    "existed_ckpt": "",
    "save_ckpt_steps": 2500,
    "keep_ckpt_max": 50,
    "ckpt_prefix": "ckpt",
    "ckpt_path": "checkpoints"
  }
}
@ -0,0 +1,232 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Configuration class for Transformer."""
import os
import json
import copy
from typing import List

import mindspore.common.dtype as mstype


def _is_dataset_file(file: str):
    return "tfrecord" in file.lower() or "mindrecord" in file.lower()


def _get_files_from_dir(folder: str):
    _files = []
    for file in os.listdir(folder):
        if _is_dataset_file(file):
            _files.append(os.path.join(folder, file))
    return _files


def get_source_list(folder: str) -> List:
    """
    Get file list from a folder.

    Returns:
        list, file list.
    """
    _list = []
    if not folder:
        return _list

    if os.path.isdir(folder):
        _list = _get_files_from_dir(folder)
    else:
        if _is_dataset_file(folder):
            _list.append(folder)
    return _list


PARAM_NODES = {"dataset_config",
               "model_config",
               "loss_scale_config",
               "learn_rate_config",
               "checkpoint_options"}


class TransformerConfig:
    """
    Configuration for `Transformer`.

    Args:
        random_seed (int): Random seed.
        batch_size (int): Batch size of input dataset.
        epochs (int): Epoch number.
        dataset_sink_mode (bool): Whether to enable dataset sink mode.
        dataset_sink_step (int): Dataset sink step.
        lr_scheduler (str): Learning rate scheduler to use; only "ISR" is supported now.
        lr (float): Initial learning rate.
        min_lr (float): Minimum learning rate.
        decay_start_step (int): Step at which decay starts.
        warmup_steps (int): Warm-up steps.
        dataset_schema (str): Path of dataset schema file.
        pre_train_dataset (str): Path of pre-training dataset file or folder.
        fine_tune_dataset (str): Path of fine-tune dataset file or folder.
        test_dataset (str): Path of test dataset file or folder.
        valid_dataset (str): Path of validation dataset file or folder.
        ckpt_path (str): Checkpoints save path.
        save_ckpt_steps (int): Interval of saving ckpt.
        ckpt_prefix (str): Prefix of ckpt file.
        keep_ckpt_max (int): Max number of ckpt files to keep.
        seq_length (int): Length of input sequence. Default: 64.
        vocab_size (int): The shape of each embedding vector. Default: 46192.
        hidden_size (int): Size of embedding, attention, dim. Default: 512.
        num_hidden_layers (int): Number of layers in the Transformer encoder and
            decoder. Default: 6.
        num_attention_heads (int): Number of attention heads in the Transformer
            encoder/decoder cell. Default: 8.
        intermediate_size (int): Size of intermediate layer in the Transformer
            encoder/decoder cell. Default: 4096.
        hidden_act (str): Activation function used in the Transformer encoder/decoder
            cell. Default: "relu".
        init_loss_scale (int): Initialized loss scale.
        loss_scale_factor (int): Loss scale factor.
        scale_window (int): Window size of loss scale.
        beam_width (int): Beam width for beam search in inferring. Default: 4.
        length_penalty_weight (float): Penalty for sentence length. Default: 1.0.
        label_smoothing (float): Label smoothing setting. Default: 0.1.
        input_mask_from_dataset (bool): Specifies whether to use the input mask loaded
            from the dataset. Default: True.
        save_graphs (bool): Whether to save graphs; set it to True if MindInsight
            is wanted.
        dtype (mstype): Data type of the input. Default: mstype.float32.
        max_decode_length (int): Max decode length for inferring. Default: 64.
        hidden_dropout_prob (float): The dropout probability for hidden outputs. Default: 0.1.
        attention_dropout_prob (float): The dropout probability for
            multi-head self-attention. Default: 0.1.
        max_position_embeddings (int): Maximum length of sequences used in this
            model. Default: 512.
        initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
    """

    def __init__(self,
                 random_seed=74,
                 batch_size=64, epochs=1,
                 dataset_sink_mode=True, dataset_sink_step=1,
                 lr_scheduler="", optimizer="adam",
                 lr=1e-4, min_lr=1e-6,
                 decay_steps=10000, poly_lr_scheduler_power=1,
                 decay_start_step=-1, warmup_steps=2000,
                 pre_train_dataset: str = None,
                 fine_tune_dataset: str = None,
                 test_dataset: str = None,
                 valid_dataset: str = None,
                 ckpt_path: str = None,
                 save_ckpt_steps=2000,
                 ckpt_prefix="CKPT",
                 existed_ckpt="",
                 keep_ckpt_max=20,
                 seq_length=128,
                 vocab_size=46192,
                 hidden_size=512,
                 num_hidden_layers=6,
                 num_attention_heads=8,
                 intermediate_size=4096,
                 hidden_act="relu",
                 hidden_dropout_prob=0.1,
                 attention_dropout_prob=0.1,
                 max_position_embeddings=64,
                 initializer_range=0.02,
                 init_loss_scale=2 ** 10,
                 loss_scale_factor=2, scale_window=2000,
                 beam_width=5,
                 length_penalty_weight=1.0,
                 label_smoothing=0.1,
                 input_mask_from_dataset=True,
                 save_graphs=False,
                 dtype=mstype.float32,
                 max_decode_length=64):

        self.save_graphs = save_graphs
        self.random_seed = random_seed
        self.pre_train_dataset = get_source_list(pre_train_dataset)  # type: List[str]
        self.fine_tune_dataset = get_source_list(fine_tune_dataset)  # type: List[str]
        self.valid_dataset = get_source_list(valid_dataset)  # type: List[str]
        self.test_dataset = get_source_list(test_dataset)  # type: List[str]

        if not isinstance(epochs, int) or epochs < 0:
            raise ValueError("`epochs` must be a non-negative int.")

        self.epochs = epochs
        self.dataset_sink_mode = dataset_sink_mode
        self.dataset_sink_step = dataset_sink_step

        self.ckpt_path = ckpt_path
        self.keep_ckpt_max = keep_ckpt_max
        self.save_ckpt_steps = save_ckpt_steps
        self.ckpt_prefix = ckpt_prefix
        self.existed_ckpt = existed_ckpt

        self.batch_size = batch_size
        self.seq_length = seq_length
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_dropout_prob = attention_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.label_smoothing = label_smoothing

        self.beam_width = beam_width
        self.length_penalty_weight = length_penalty_weight
        self.max_decode_length = max_decode_length
        self.input_mask_from_dataset = input_mask_from_dataset
        self.compute_type = mstype.float16
        self.dtype = dtype

        self.scale_window = scale_window
        self.loss_scale_factor = loss_scale_factor
        self.init_loss_scale = init_loss_scale

        self.optimizer = optimizer
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.min_lr = min_lr
        self.poly_lr_scheduler_power = poly_lr_scheduler_power
        self.decay_steps = decay_steps
        self.decay_start_step = decay_start_step
        self.warmup_steps = warmup_steps

        self.train_url = ""

    @classmethod
    def from_dict(cls, json_object: dict):
        """Constructs a `TransformerConfig` from a Python dictionary of parameters."""
        _params = {}
        for node in PARAM_NODES:
            for key in json_object[node]:
                _params[key] = json_object[node][key]
        return cls(**_params)

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `TransformerConfig` from a json file of parameters."""
        with open(json_file, "r") as reader:
            return cls.from_dict(json.load(reader))

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
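For reference, a minimal usage sketch of how the JSON above feeds this class: from_json_file flattens the five PARAM_NODES sections into one keyword dictionary and forwards it to __init__. The file path and import path below are illustrative assumptions, not part of this PR.

from config import TransformerConfig

cfg = TransformerConfig.from_json_file("config/config.json")   # hypothetical path
print(cfg.hidden_size, cfg.num_attention_heads)                # taken from "model_config"
print(cfg.pre_train_dataset)                                   # [] when the JSON field is an empty string

# Equivalent construction from an in-memory dict; all five top-level nodes must be present.
cfg2 = TransformerConfig.from_dict({
    "dataset_config": {"epochs": 1, "batch_size": 32},
    "model_config": {"seq_length": 64, "vocab_size": 45744},
    "loss_scale_config": {"init_loss_scale": 65536},
    "learn_rate_config": {"lr": 1e-4},
    "checkpoint_options": {"ckpt_path": "checkpoints"},
})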
@ -0,0 +1,110 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate Cornell Movie Dialog dataset."""
import os
import argparse
from src.dataset import BiLingualDataLoader
from src.language_model import NoiseChannelLanguageModel
from src.utils import Dictionary

parser = argparse.ArgumentParser(description='Generate Cornell Movie Dialog dataset file.')
parser.add_argument("--src_folder", type=str, default="", required=True,
                    help="Raw corpus folder.")
parser.add_argument("--existed_vocab", type=str, default="", required=True,
                    help="Existing vocabulary.")
parser.add_argument("--train_prefix", type=str, default="train", required=False,
                    help="Prefix of train file.")
parser.add_argument("--test_prefix", type=str, default="test", required=False,
                    help="Prefix of test file.")
parser.add_argument("--valid_prefix", type=str, default=None, required=False,
                    help="Prefix of valid file.")
parser.add_argument("--noise_prob", type=float, default=0., required=False,
                    help="Probability of adding noise.")
parser.add_argument("--max_len", type=int, default=32, required=False,
                    help="Max length of sentence.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="Dataset output path.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    dicts = []
    train_src_file = ""
    train_tgt_file = ""
    test_src_file = ""
    test_tgt_file = ""
    valid_src_file = ""
    valid_tgt_file = ""
    for file in os.listdir(args.src_folder):
        if file.startswith(args.train_prefix) and "src" in file and file.endswith(".txt"):
            train_src_file = os.path.join(args.src_folder, file)
        elif file.startswith(args.train_prefix) and "tgt" in file and file.endswith(".txt"):
            train_tgt_file = os.path.join(args.src_folder, file)
        elif file.startswith(args.test_prefix) and "src" in file and file.endswith(".txt"):
            test_src_file = os.path.join(args.src_folder, file)
        elif file.startswith(args.test_prefix) and "tgt" in file and file.endswith(".txt"):
            test_tgt_file = os.path.join(args.src_folder, file)
        elif args.valid_prefix and file.startswith(args.valid_prefix) and "src" in file and file.endswith(".txt"):
            valid_src_file = os.path.join(args.src_folder, file)
        elif args.valid_prefix and file.startswith(args.valid_prefix) and "tgt" in file and file.endswith(".txt"):
            valid_tgt_file = os.path.join(args.src_folder, file)
        else:
            continue

    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)

    if train_src_file and train_tgt_file:
        BiLingualDataLoader(
            src_filepath=train_src_file,
            tgt_filepath=train_tgt_file,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=args.noise_prob),
            max_sen_len=args.max_len
        ).write_to_tfrecord(
            path=os.path.join(
                args.output_folder, "train_cornell_dialog.tfrecord"
            )
        )

    if test_src_file and test_tgt_file:
        BiLingualDataLoader(
            src_filepath=test_src_file,
            tgt_filepath=test_tgt_file,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=0.),
            max_sen_len=args.max_len
        ).write_to_tfrecord(
            path=os.path.join(
                args.output_folder, "test_cornell_dialog.tfrecord"
            )
        )

    if args.valid_prefix and valid_src_file and valid_tgt_file:
        BiLingualDataLoader(
            src_filepath=valid_src_file,
            tgt_filepath=valid_tgt_file,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=0.),
            max_sen_len=args.max_len
        ).write_to_tfrecord(
            path=os.path.join(
                args.output_folder, "valid_cornell_dialog.tfrecord"
            )
        )

    print(f" | Vocabulary size: {vocab.size}.")
@ -0,0 +1,75 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Evaluation API."""
import argparse
import pickle
import numpy as np

from mindspore.common import dtype as mstype

from config import TransformerConfig
from src.transformer import infer
from src.utils import ngram_ppl
from src.utils import Dictionary
from src.utils import rouge

parser = argparse.ArgumentParser(description='Evaluate MASS.')
parser.add_argument("--config", type=str, required=True,
                    help="Model config json file path.")
parser.add_argument("--vocab", type=str, required=True,
                    help="Vocabulary to use.")
parser.add_argument("--output", type=str, required=True,
                    help="Result file path.")


def get_config(config):
    config = TransformerConfig.from_json_file(config)
    config.compute_type = mstype.float16
    config.dtype = mstype.float32
    return config


if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    vocab = Dictionary.load_from_persisted_dict(args.vocab)
    _config = get_config(args.config)
    result = infer(_config)
    with open(args.output, "wb") as f:
        pickle.dump(result, f, 1)

    ppl_score = 0.
    preds = []
    tgts = []
    _count = 0
    for sample in result:
        sentence_prob = np.array(sample['prediction_prob'], dtype=np.float32)
        sentence_prob = sentence_prob[:, 1:]
        _ppl = []
        for path in sentence_prob:
            _ppl.append(ngram_ppl(path, log_softmax=True))
        ppl = np.min(_ppl)
        preds.append(' '.join([vocab[t] for t in sample['prediction']]))
        tgts.append(' '.join([vocab[t] for t in sample['target']]))
        print(f" | source: {' '.join([vocab[t] for t in sample['source']])}")
        print(f" | target: {tgts[-1]}")
        print(f" | prediction: {preds[-1]}")
        print(f" | ppl: {ppl}.")
        if np.isinf(ppl):
            continue
        ppl_score += ppl
        _count += 1

    print(f" | PPL={ppl_score / _count}.")
    rouge(preds, tgts)
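For intuition, the evaluation loop reduces each sample to the best (lowest) perplexity over its beam paths and then averages over samples, skipping infinities. Below is a rough stand-alone sketch of that reduction; path_ppl is a simplified stand-in for src.utils.ngram_ppl (assumed here to turn a path of token log-probabilities into a perplexity), not the actual implementation.

import numpy as np

def path_ppl(log_probs):
    """Simplified stand-in: perplexity = exp(-mean token log-probability)."""
    return float(np.exp(-np.mean(log_probs)))

# One "sample": beam_width x decode_length matrix of token log-probabilities.
beam_log_probs = np.log(np.array([[0.5, 0.4, 0.6],
                                  [0.2, 0.1, 0.3]], dtype=np.float32))

per_path = [path_ppl(path) for path in beam_log_probs]
sample_ppl = min(per_path)   # keep the best beam path, as the loop above does
print(f" | per-path ppl: {per_path}, sample ppl: {sample_ppl:.3f}")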
@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate Gigaword dataset."""
import os
import argparse

from src.dataset import BiLingualDataLoader
from src.language_model import NoiseChannelLanguageModel
from src.utils import Dictionary

parser = argparse.ArgumentParser(description='Create Gigaword fine-tune Dataset.')
parser.add_argument("--train_src", type=str, default="", required=False,
                    help="train dataset source file path.")
parser.add_argument("--train_ref", type=str, default="", required=False,
                    help="train dataset reference file path.")
parser.add_argument("--test_src", type=str, default="", required=False,
                    help="test dataset source file path.")
parser.add_argument("--test_ref", type=str, default="", required=False,
                    help="test dataset reference file path.")
parser.add_argument("--noise_prob", type=float, default=0., required=False,
                    help="add noise prob.")
parser.add_argument("--existed_vocab", type=str, default="", required=False,
                    help="existed vocab path.")
parser.add_argument("--max_len", type=int, default=64, required=False,
                    help="max length of sentences.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="dataset output path.")
parser.add_argument("--format", type=str, default="tfrecord", required=False,
                    help="dataset format.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)

    if args.train_src and args.train_ref:
        train = BiLingualDataLoader(
            src_filepath=args.train_src,
            tgt_filepath=args.train_ref,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=args.noise_prob),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            train.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.tfrecord")
            )
        else:
            train.write_to_mindrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.mindrecord")
            )

    if args.test_src and args.test_ref:
        test = BiLingualDataLoader(
            src_filepath=args.test_src,
            tgt_filepath=args.test_ref,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=0),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            test.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_test_dataset.tfrecord")
            )
        else:
            test.write_to_mindrecord(
                path=os.path.join(args.output_folder, "gigaword_test_dataset.mindrecord")
            )

    print(f" | Vocabulary size: {vocab.size}.")
@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate News Crawl corpus dataset."""
import argparse

from src.utils import Dictionary
from src.utils.preprocess import create_pre_training_dataset

parser = argparse.ArgumentParser(description='Create News Crawl Pre-Training Dataset.')
parser.add_argument("--src_folder", type=str, default="", required=True,
                    help="Raw corpus folder.")
parser.add_argument("--existed_vocab", type=str, default="", required=True,
                    help="Existing vocab path.")
parser.add_argument("--mask_ratio", type=float, default=0.4, required=True,
                    help="Mask ratio.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="Dataset output path.")
parser.add_argument("--max_len", type=int, default=32, required=False,
                    help="Max length of sentences.")
parser.add_argument("--suffix", type=str, default="", required=False,
                    help="Suffix to add to the output file.")
parser.add_argument("--processes", type=int, default=2, required=False,
                    help="Size of processes pool.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    if not (args.src_folder and args.output_folder):
        raise ValueError("Please enter required params.")

    if not args.existed_vocab:
        raise ValueError("`--existed_vocab` is required.")

    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)

    create_pre_training_dataset(
        folder_path=args.src_folder,
        output_folder_path=args.output_folder,
        vocabulary=vocab,
        prefix="news.20", suffix=args.suffix,
        mask_ratio=args.mask_ratio,
        min_sen_len=10,
        max_sen_len=args.max_len,
        dataset_type="tfrecord",
        cores=args.processes
    )
    print(f" | Vocabulary size: {vocab.size}.")
@ -0,0 +1,5 @@
nltk
jieba
numpy
subword-nmt
files2rouge
@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

src_folder_path=$1  # source text folder path.

cd $src_folder_path || exit
cat *.txt | subword-nmt learn-bpe -s 46000 -o all.bpe.codes
@ -0,0 +1,169 @@
#!/usr/bin/env bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

options=`getopt -u -o ht:n:i:j:c:o:v: -l help,task:,device_num:,device_id:,hccl_json:,config:,output:,vocab: -- "$@"`
eval set -- "$options"
echo $options

echo_help()
{
    echo "Usage:"
    echo "bash train.sh [-h] [-t t|i] [-n N] [-i N] [-j FILE] [-c FILE] [-o FILE] [-v FILE]"
    echo "options:"
    echo "    -h --help          show usage"
    echo "    -t --task          select task, 't' for training and 'i' for inference"
    echo "    -n --device_num    training with N devices"
    echo "    -i --device_id     training with device i"
    echo "    -j --hccl_json     set the rank table file"
    echo "    -c --config        set the configuration file"
    echo "    -o --output        set the output file of inference"
    echo "    -v --vocab         set the vocabulary"
}

set_hccl_json()
{
    while [ -n "$1" ]
    do
        if [[ "$1" == "-j" || "$1" == "--hccl_json" ]]
        then
            export MINDSPORE_HCCL_CONFIG_PATH=$2 #/data/wsc/hccl_2p_01.json
            export RANK_TABLE_FILE=$2 #/data/wsc/hccl_2p_01.json
            break
        fi
        shift
    done
}
set_device_id()
{
    while [ -n "$1" ]
    do
        if [[ "$1" == "-i" || "$1" == "--device_id" ]]
        then
            if [[ $2 -ge 0 && $2 -le 7 ]]
            then
                export DEVICE_ID=$2
            fi
            break
        fi
        shift
    done
}

while [ -n "$1" ]
do
    case "$1" in
    -h|--help)
        echo_help
        shift
        ;;
    -t|--task)
        echo "task:"
        if [ "$2" == "t" ]
        then
            task=train
        elif [ "$2" == "i" ]
        then
            task=infer
        fi
        shift 2
        ;;
    -n|--device_num)
        echo "device_num"
        if [ $2 -eq 1 ]
        then
            set_device_id $options
        elif [ $2 -gt 1 ]
        then
            export HCCL_FLAG=1
            export DEPLOY_MODE=0

            export RANK_SIZE=$2
            set_hccl_json $options
        fi
        shift 2
        ;;
    -i|--device_id)
        echo "set device id"
        export DEVICE_ID=$2
        shift 2
        ;;
    -c|--config)
        echo "config";
        configurations=$2
        shift 2
        ;;
    -o|--output)
        echo "output";
        output=$2
        shift 2
        ;;
    -v|--vocab)
        echo "vocab";
        vocab=$2
        shift 2
        ;;
    --)
        shift
        break
        ;;
    *)
        shift
        ;;
    esac
done

for((i=0; i < $RANK_SIZE; i++))
do
    if [ $RANK_SIZE -gt 1 ]
    then
        echo $RANK_SIZE
        export RANK_ID=$i
        export DEVICE_ID=$[i]
    fi
    echo "Working on device $i"

    file_path=$(cd "$(dirname $0)" || exit; pwd)
    cd $file_path || exit
    cd ../ || exit

    rm -rf ./run_mass_$DEVICE_ID
    mkdir ./run_mass_$DEVICE_ID

    cp train.py ./run_mass_$DEVICE_ID
    cp eval.py ./run_mass_$DEVICE_ID
    cp $configurations ./run_mass_$DEVICE_ID

    if [ $vocab ]
    then
        cp $vocab ./run_mass_$DEVICE_ID
    fi

    cd ./run_mass_$DEVICE_ID || exit
    env > log.log
    echo $task
    if [ "$task" == "train" ]
    then
        python train.py --config ${configurations##*/} >>log.log 2>&1 &
    elif [ "$task" == "infer" ]
    then
        python eval.py --config ${configurations##*/} --output ${output} --vocab ${vocab##*/} >>log_infer.log 2>&1 &
    fi
    cd ../
done
@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Source of the MASS model."""
from .dataset import load_dataset
from .dataset import bi_data_loader
from .dataset import mono_data_loader
from .transformer import TransformerDecoder
from .transformer import TransformerEncoder
from .transformer import Transformer
from .transformer import TransformerNetworkWithLoss
from .transformer import LabelSmoothedCrossEntropyCriterion
from .transformer import TransformerTrainOneStepWithLossScaleCell
from .transformer import TransformerTraining
from .transformer import infer
from .language_model import LooseMaskedLanguageModel
from .language_model import MaskedLanguageModel
from .language_model import NoiseChannelLanguageModel

__all__ = [
    "load_dataset",
    "bi_data_loader",
    "mono_data_loader",
    "Transformer",
    "infer",
    "TransformerTraining",
    "TransformerNetworkWithLoss",
    "TransformerTrainOneStepWithLossScaleCell",
    "LabelSmoothedCrossEntropyCriterion",
    "LooseMaskedLanguageModel",
    "MaskedLanguageModel",
    "NoiseChannelLanguageModel"
]
@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset module."""
from .bi_data_loader import BiLingualDataLoader
from .mono_data_loader import MonoLingualDataLoader
from .load_dataset import load_dataset

__all__ = [
    "load_dataset",
    "BiLingualDataLoader",
    "MonoLingualDataLoader"
]
@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base class of data loader."""
import os
import collections
import numpy as np

from mindspore.mindrecord import FileWriter
from .schema import SCHEMA


class DataLoader:
    """Data loader for dataset."""
    _SCHEMA = SCHEMA

    def __init__(self, max_sen_len=66):
        self._examples = []
        self._max_sentence_len = max_sen_len

    def _load(self):
        raise NotImplementedError

    def padding(self, sen, padding_idx, dtype=np.int64):
        """Pad the sentence with `padding_idx` up to the max sentence length."""
        if sen.shape[0] > self._max_sentence_len:
            return None
        new_sen = np.array([padding_idx] * self._max_sentence_len,
                           dtype=dtype)
        new_sen[:sen.shape[0]] = sen[:]
        return new_sen

    def write_to_mindrecord(self, path, shard_num=1, desc=""):
        """
        Write mindrecord file.

        Args:
            path (str): File path.
            shard_num (int): Shard num.
            desc (str): Description.
        """
        if not os.path.isabs(path):
            path = os.path.abspath(path)

        writer = FileWriter(file_name=path, shard_num=shard_num)
        writer.add_schema(self._SCHEMA, desc)
        if not self._examples:
            self._load()

        writer.write_raw_data(self._examples)
        writer.commit()
        print(f"| Wrote to {path}.")

    def write_to_tfrecord(self, path, shard_num=1):
        """
        Write to tfrecord.

        Args:
            path (str): Output file path.
            shard_num (int): Shard num.
        """
        import tensorflow as tf
        if not os.path.isabs(path):
            path = os.path.abspath(path)
        output_files = []
        for i in range(shard_num):
            output_file = path + "-%03d-of-%03d" % (i + 1, shard_num)
            output_files.append(output_file)
        # Create writers.
        writers = []
        for output_file in output_files:
            writers.append(tf.io.TFRecordWriter(output_file))

        if not self._examples:
            self._load()

        # Create features and write one record per example.
        features = collections.OrderedDict()
        for example in self._examples:
            for key in example:
                features[key] = tf.train.Feature(int64_list=tf.train.Int64List(value=example[key].tolist()))
            tf_example = tf.train.Example(features=tf.train.Features(feature=features))
            for writer in writers:
                writer.write(tf_example.SerializeToString())
        for writer in writers:
            writer.close()
        for p in output_files:
            print(f" | Write to {p}.")

    def _add_example(self, example):
        self._examples.append(example)
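To make the padding contract concrete, here is a small, hedged sketch of what DataLoader.padding does with a short token sequence (a stand-alone reimplementation for illustration, not an import of the class above):

import numpy as np

def pad(sen, padding_idx, max_len=10, dtype=np.int64):
    """Mirror of DataLoader.padding: right-pad with padding_idx, or drop over-long sentences."""
    if sen.shape[0] > max_len:
        return None                  # the loaders skip such examples
    padded = np.full(max_len, padding_idx, dtype=dtype)
    padded[:sen.shape[0]] = sen
    return padded

print(pad(np.array([5, 7, 2]), padding_idx=0))   # [5 7 2 0 0 0 0 0 0 0]
print(pad(np.arange(12), padding_idx=0))         # None -> example is dropped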
@ -0,0 +1,142 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Bilingual data loader."""
import numpy as np

from src.utils import Dictionary

from .base import DataLoader
from ..language_model.base import LanguageModel
from ..language_model.noise_channel_language_model import NoiseChannelLanguageModel


class BiLingualDataLoader(DataLoader):
    """Loader for bilingual data."""

    def __init__(self, src_filepath: str, tgt_filepath: str,
                 src_dict: Dictionary, tgt_dict: Dictionary,
                 src_lang: str, tgt_lang: str,
                 language_model: LanguageModel = NoiseChannelLanguageModel(add_noise_prob=0),
                 max_sen_len=66,
                 merge_dict=True):
        super(BiLingualDataLoader, self).__init__(max_sen_len)
        self._src_filepath = src_filepath
        self._tgt_filepath = tgt_filepath
        self._src_dict = src_dict
        self._tgt_dict = tgt_dict
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self._lm = language_model
        self.max_sen_len = max_sen_len
        self.share_dict = merge_dict
        self._merge_dict()

    def _merge_dict(self):
        if self.share_dict:
            merged_dict = self._src_dict.merge_dict(self._tgt_dict,
                                                    new_dict=True)
            self._src_dict = merged_dict
            self._tgt_dict = merged_dict

    @property
    def src_dict(self):
        return self._src_dict

    @property
    def tgt_dict(self):
        return self._tgt_dict

    def _load(self):
        _min_len = 9999999999
        _max_len = 0
        unk_count = 0
        tokens_count = 0
        count = 0
        with open(self._src_filepath, "r") as _src_file:
            print(f" | Processing corpus {self._src_filepath}.")
            print(f" | Processing corpus {self._tgt_filepath}.")
            with open(self._tgt_filepath, "r") as _tgt_file:
                _min, _max = 9999999, -1
                for _, _pair in enumerate(zip(_src_file, _tgt_file)):
                    src_tokens = [
                        self._src_dict.index(t)
                        for t in _pair[0].strip().split(" ") if t
                    ]
                    tgt_tokens = [
                        self._tgt_dict.index(t)
                        for t in _pair[1].strip().split(" ") if t
                    ]
                    src_tokens.append(self._src_dict.eos_index)
                    tgt_tokens.append(self._tgt_dict.eos_index)
                    opt = self._lm.emit(
                        sentence=np.array(src_tokens, dtype=np.int64),
                        target=np.array(tgt_tokens, dtype=np.int64),
                        mask_symbol_idx=self._src_dict.mask_index,
                        bos_symbol_idx=self._tgt_dict.bos_index
                    )
                    src_len = opt["sentence_length"]
                    tgt_len = opt["tgt_sen_length"]

                    _min_len = min(_min_len, opt["sentence_length"], opt["tgt_sen_length"])
                    _max_len = max(_max_len, opt["sentence_length"], opt["tgt_sen_length"])

                    if src_len > self.max_sen_len or tgt_len > self.max_sen_len:
                        continue

                    src_padding = np.zeros(shape=self.max_sen_len, dtype=np.int64)
                    tgt_padding = np.zeros(shape=self.max_sen_len, dtype=np.int64)
                    for i in range(src_len):
                        src_padding[i] = 1
                    for j in range(tgt_len):
                        tgt_padding[j] = 1

                    tokens_count += opt["encoder_input"].shape[0]
                    tokens_count += opt["decoder_input"].shape[0]
                    tokens_count += opt["decoder_output"].shape[0]
                    unk_count += np.where(opt["encoder_input"] == self._src_dict.unk_index)[0].shape[0]
                    unk_count += np.where(opt["decoder_input"] == self._src_dict.unk_index)[0].shape[0]
                    unk_count += np.where(opt["decoder_output"] == self._src_dict.unk_index)[0].shape[0]

                    encoder_input = self.padding(opt["encoder_input"],
                                                 self._src_dict.padding_index)
                    decoder_input = self.padding(opt["decoder_input"],
                                                 self._tgt_dict.padding_index)
                    decoder_output = self.padding(opt["decoder_output"],
                                                  self._tgt_dict.padding_index)
                    if encoder_input is None or decoder_input is None or decoder_output is None:
                        continue

                    _min = np.min([np.min(encoder_input),
                                   np.min(decoder_input),
                                   np.min(decoder_output), _min])
                    _max = np.max([np.max(encoder_input),
                                   np.max(decoder_input),
                                   np.max(decoder_output), _max])

                    example = {
                        "src_padding": src_padding,
                        "tgt_padding": tgt_padding,
                        "src": encoder_input,
                        "prev_opt": decoder_input,
                        "prev_padding": tgt_padding,
                        "target": decoder_output
                    }
                    self._add_example(example)
                    count += 1

        print(f" | Shortest len = {_min_len}.")
        print(f" | Longest len = {_max_len}.")
        print(f" | Total sen = {count}.")
        print(f" | Total token num={tokens_count}, "
              f"{unk_count / tokens_count * 100}% replaced by <unk>.")
@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset loader to feed into model."""
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC


def _load_dataset(input_files, batch_size, epoch_count=1,
                  sink_mode=False, sink_step=1, rank_size=1, rank_id=0, shuffle=True):
    """
    Load dataset according to passed in params.

    Args:
        input_files (list): Data files.
        batch_size (int): Batch size.
        epoch_count (int): Epoch count.
        sink_mode (bool): Whether to enable sink mode.
        sink_step (int): Step to sink.
        rank_size (int): Rank size.
        rank_id (int): Rank id.
        shuffle (bool): Whether to shuffle dataset.

    Returns:
        Dataset, dataset instance.
    """
    if not input_files:
        raise FileNotFoundError("Require at least one dataset.")

    if not isinstance(sink_mode, bool):
        raise ValueError("`sink_mode` must be type of bool.")

    for datafile in input_files:
        print(f" | Loading {datafile}.")

    ds = de.TFRecordDataset(
        input_files,
        columns_list=[
            "src", "src_padding",
            "prev_opt", "prev_padding",
            "target", "tgt_padding"
        ],
        shuffle=shuffle, num_shards=rank_size, shard_id=rank_id,
        shard_equal_rows=True, num_parallel_workers=8)

    ori_dataset_size = ds.get_dataset_size()
    print(f" | Dataset size: {ori_dataset_size}.")
    repeat_count = epoch_count
    if sink_mode:
        ds.set_dataset_size(sink_step * batch_size)
        repeat_count = epoch_count * ori_dataset_size // ds.get_dataset_size()

    type_cast_op = deC.TypeCast(mstype.int32)
    ds = ds.map(input_columns="src", operations=type_cast_op)
    ds = ds.map(input_columns="src_padding", operations=type_cast_op)
    ds = ds.map(input_columns="prev_opt", operations=type_cast_op)
    ds = ds.map(input_columns="prev_padding", operations=type_cast_op)
    ds = ds.map(input_columns="target", operations=type_cast_op)
    ds = ds.map(input_columns="tgt_padding", operations=type_cast_op)

    ds = ds.rename(
        input_columns=["src",
                       "src_padding",
                       "prev_opt",
                       "prev_padding",
                       "target",
                       "tgt_padding"],
        output_columns=["source_eos_ids",
                        "source_eos_mask",
                        "target_sos_ids",
                        "target_sos_mask",
                        "target_eos_ids",
                        "target_eos_mask"]
    )

    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(repeat_count)

    ds.channel_name = 'transformer'
    return ds


def load_dataset(data_files: list, batch_size: int, epoch_count: int,
                 sink_mode: bool, sink_step: int = 1, rank_size: int = 1, rank_id: int = 0, shuffle=True):
    """
    Load dataset.

    Args:
        data_files (list): Data files.
        batch_size (int): Batch size.
        epoch_count (int): Epoch count.
        sink_mode (bool): Whether to enable sink mode.
        sink_step (int): Step to sink.
        rank_size (int): Rank size.
        rank_id (int): Rank id.
        shuffle (bool): Whether to shuffle dataset.

    Returns:
        Dataset, dataset instance.
    """
    return _load_dataset(data_files, batch_size, epoch_count, sink_mode,
                         sink_step, rank_size, rank_id, shuffle=shuffle)
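A hedged usage sketch of the loader above, assuming a list of TFRecord files produced by the dataset-generation scripts earlier in this PR (the file path below is an illustrative assumption, not part of this change):

from src.dataset import load_dataset

ds = load_dataset(
    data_files=["/data/mass/gigaword_train_dataset.tfrecord-001-of-001"],  # hypothetical path
    batch_size=192,
    epoch_count=20,
    sink_mode=False,
    rank_size=1, rank_id=0,
    shuffle=True)

# After `rename`, each batch exposes the model-facing column names:
# source_eos_ids, source_eos_mask, target_sos_ids, target_sos_mask, target_eos_ids, target_eos_mask.
for batch in ds.create_dict_iterator():
    print(batch["source_eos_ids"].shape)   # (192, seq_length)
    break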
@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mono data loader."""
import numpy as np

from src.utils import Dictionary

from .base import DataLoader
from .schema import SCHEMA
from ..language_model.base import LanguageModel
from ..language_model import LooseMaskedLanguageModel


class MonoLingualDataLoader(DataLoader):
    """Loader for monolingual data."""
    _SCHEMA = SCHEMA

    def __init__(self, src_filepath: str, lang: str, dictionary: Dictionary,
                 language_model: LanguageModel = LooseMaskedLanguageModel(mask_ratio=0.3),
                 max_sen_len=66, min_sen_len=16):
        super(MonoLingualDataLoader, self).__init__(max_sen_len=max_sen_len)
        self._file_path = src_filepath
        self._lang = lang
        self._dictionary = dictionary
        self._lm = language_model
        self.max_sen_len = max_sen_len
        self.min_sen_len = min_sen_len

    @property
    def dict(self):
        return self._dictionary

    def generate_padding_mask(self, sentence, length, exclude_mask=False):
        """Generate padding mask vector."""
        src_padding = np.zeros(shape=self.max_sen_len, dtype=np.int64)
        if exclude_mask:
            pos = np.where(sentence == self._dictionary.padding_index)[0]
        else:
            pos = np.where((sentence == self._dictionary.padding_index) | (sentence == self._dictionary.mask_index))[0]
        src_padding[0:length] = 1
        if pos.shape[0] != 0:
            src_padding[pos] = 0
        return src_padding

    def _load(self):
        _min_len = 9999999999
        _max_len = 0
        count = 0
        with open(self._file_path, "r") as _file:
            print(f" | Processing corpus {self._file_path}.")
            for _, _line in enumerate(_file):
                tokens = [self._dictionary.index(t.replace(" ", ""))
                          for t in _line.strip().split(" ") if t]
                # In the original MASS code, <BOS> is not added to the sentence.
                tokens.append(self._dictionary.eos_index)
                opt = self._lm.emit(sentence=np.array(tokens, dtype=np.int32),
                                    vocabulary=self._dictionary)

                src_len = opt["sentence_length"]
                _min_len = min(_min_len, opt["sentence_length"], opt["tgt_sen_length"])
                _max_len = max(_max_len, opt["sentence_length"], opt["tgt_sen_length"])

                if src_len > self.max_sen_len:
                    continue
                if src_len < self.min_sen_len:
                    continue

                src_padding = self.generate_padding_mask(opt["encoder_input"],
                                                         opt["sentence_length"],
                                                         exclude_mask=False)
                tgt_padding = self.generate_padding_mask(opt["decoder_input"],
                                                         opt["tgt_sen_length"],
                                                         exclude_mask=True)

                encoder_input = self.padding(opt["encoder_input"],
                                             self._dictionary.padding_index)
                decoder_input = self.padding(opt["decoder_input"],
                                             self._dictionary.padding_index)
                decoder_output = self.padding(opt["decoder_output"],
                                              self._dictionary.padding_index)
                if encoder_input is None or decoder_input is None or decoder_output is None:
                    continue

                example = {
                    "src": encoder_input,
                    "src_padding": src_padding,
                    "prev_opt": decoder_input,
                    "prev_padding": tgt_padding,
                    "target": decoder_output,
                    "tgt_padding": tgt_padding,
                }
                self._add_example(example)
                count += 1

        print(f" | Shortest len = {_min_len}.")
        print(f" | Longest len = {_max_len}.")
        print(f" | Total sen = {count}.")
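For intuition, a self-contained sketch of the masking behaviour of generate_padding_mask: padding positions are always zeroed, and masked tokens are zeroed too unless exclude_mask=True. The PAD/MASK ids below are made-up stand-ins for the Dictionary indices, not values from this PR.

import numpy as np

PAD, MASK = 0, 4          # hypothetical special-token ids
MAX_LEN = 8

def padding_mask(sentence, length, exclude_mask=False):
    """Same rule as MonoLingualDataLoader.generate_padding_mask, with hard-coded ids."""
    mask = np.zeros(MAX_LEN, dtype=np.int64)
    mask[:length] = 1
    if exclude_mask:
        drop = np.where(sentence == PAD)[0]
    else:
        drop = np.where((sentence == PAD) | (sentence == MASK))[0]
    mask[drop] = 0
    return mask

sen = np.array([7, MASK, MASK, 9, 2, PAD, PAD, PAD])
print(padding_mask(sen, length=5, exclude_mask=False))  # [1 0 0 1 1 0 0 0] encoder side
print(padding_mask(sen, length=5, exclude_mask=True))   # [1 1 1 1 1 0 0 0] decoder side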
@ -0,0 +1,24 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define schema of mindrecord."""

SCHEMA = {
    "src": {"type": "int64", "shape": [-1]},
    "src_padding": {"type": "int64", "shape": [-1]},
    "prev_opt": {"type": "int64", "shape": [-1]},
    "prev_padding": {"type": "int64", "shape": [-1]},
    "target": {"type": "int64", "shape": [-1]},
    "tgt_padding": {"type": "int64", "shape": [-1]},
}
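Records emitted by the loader above are meant to match this schema. The following is a hedged sketch of serializing one record with MindSpore's mindrecord FileWriter; the output path, shard count, and zero-filled placeholder arrays are illustrative, and the import path of SCHEMA is assumed:

import numpy as np
from mindspore.mindrecord import FileWriter

from src.dataset.schema import SCHEMA  # assumed module path for the schema above

writer = FileWriter(file_name="mono_demo.mindrecord", shard_num=1)  # illustrative path
writer.add_schema(SCHEMA, "MASS monolingual example")
record = {name: np.zeros(32, dtype=np.int64) for name in SCHEMA}    # placeholder arrays
writer.write_raw_data([record])
writer.commit()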
@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Language model."""
from .noise_channel_language_model import NoiseChannelLanguageModel
from .masked_language_model import MaskedLanguageModel
from .loose_masked_language_model import LooseMaskedLanguageModel
from .mass_language_model import MassLanguageModel

__all__ = [
    "LooseMaskedLanguageModel",
    "MassLanguageModel",
    "MaskedLanguageModel",
    "NoiseChannelLanguageModel"
]
@ -0,0 +1,25 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base language model."""


class LanguageModel:
    """Define base language model."""

    def __init__(self):
        pass

    def emit(self, **kwargs):
        raise NotImplementedError
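Concrete models in this package override emit(). A minimal sketch of the expected contract follows; the IdentityLanguageModel class is illustrative only (it applies no masking), and the output keys mirror what the data loader above consumes:

import numpy as np

from .base import LanguageModel


class IdentityLanguageModel(LanguageModel):
    """Toy model that applies no masking; for illustration only."""

    def emit(self, sentence: np.ndarray, vocabulary):
        length = sentence.shape[0]
        return {
            "sentence_length": length,
            "tgt_sen_length": length,
            "encoder_input": sentence.copy(),
            "decoder_input": sentence.copy(),
            "decoder_output": sentence.copy(),
        }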
@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Modified masked language model."""
import numpy as np

from src.utils import Dictionary

from .base import LanguageModel


class LooseMaskedLanguageModel(LanguageModel):
    """
    Modified mask operation on a sentence.

    If k is assigned, mask a fragment of length k.
    Otherwise, use mask_ratio.

    Args:
        k (int): Length of the masked fragment.
        mask_ratio (float): Mask ratio.
    """

    def __init__(self, k: int = None, mask_ratio=0.5,
                 mask_all_prob=None):
        super(LooseMaskedLanguageModel, self).__init__()
        self.mask_ratio = mask_ratio
        self._k = k
        self._threshold = mask_all_prob

    def emit(self, sentence: np.ndarray, vocabulary: Dictionary):
        """
        Mask a monolingual source sentence.

        A training sample is produced with the following steps:

            encoder input (source): [x1, x2, x3, x4, x5, x6, x7, x8, </eos>]
            masked encoder input:   [x1, x2, x3,  _,  _,  _, x7, x8, </eos>]

            decoder input:          [ -, x3, x4, x5]
                                      |   |   |   |
                                      V   V   V   V
            decoder output:         [x3, x4, x5, x6]

        Notes:
            By convention the source sentence does not start with <BOS>,
            but it ends with <EOS>.

        Args:
            vocabulary (Dictionary): Vocabulary.
            sentence (np.ndarray): Raw sentence instance.

        Returns:
            dict, an example.
        """
        # The masked interval is [u, v). If v == 0, then u must also be 0.
        u, v = self._get_masked_interval(sentence.shape[0],
                                         self._k, self._threshold)

        encoder_input = sentence.copy()
        right_shifted_sentence = np.concatenate(([vocabulary.bos_index], sentence[:-1]))

        if u == 0:
            _len = v - u if v - u != 0 else sentence.shape[0]
            decoder_input = right_shifted_sentence[:_len]
            decoder_input[0] = vocabulary.mask_index
            decoder_output = sentence[:_len].copy()
        else:
            decoder_input = right_shifted_sentence[u - 1:v]
            decoder_input[0] = vocabulary.mask_index
            decoder_output = sentence[u - 1:v].copy()

        if v == 0:
            decoder_input[:] = vocabulary.mask_index
        else:
            encoder_input[np.arange(start=u, stop=v)] = vocabulary.mask_index

        if u != v and u > 1:
            padding = np.array([vocabulary.padding_index] * (u - 1), dtype=np.int32)
            decoder_input = np.concatenate((padding, decoder_input))
            decoder_output = np.concatenate((padding, decoder_output))

        if decoder_input.shape[0] != decoder_output.shape[0]:
            raise ValueError("Decoder input and output must have the same length.")

        return {
            "sentence_length": sentence.shape[0],
            "tgt_sen_length": decoder_output.shape[0],
            "encoder_input": encoder_input,  # ends with </eos>
            "decoder_input": decoder_input,
            "decoder_output": decoder_output  # ends with </eos>
        }

    def _get_masked_interval(self, length, fix_length=None,
                             threshold_to_mask_all=None):
        """
        Choose the interval to mask according to length and mask_ratio.

        Args:
            length (int): Sequence length.
            fix_length (int): If given, mask a fragment of exactly this length.
            threshold_to_mask_all (float): Probability of masking the whole sentence.

        Returns:
            Tuple[int, int], (start position, end position).
        """
        # The masked length cannot be larger than the sequence length;
        # it lies in [0, length].
        if fix_length is not None:
            interval_length = min(length, fix_length)
        else:
            interval_length = min(length, round(self.mask_ratio * length))

        _magic = np.random.random()
        if threshold_to_mask_all is not None and _magic <= threshold_to_mask_all:
            return 0, length

        # If no sequence is to be masked, return (0, 0).
        if interval_length == 0:
            return 0, 0
        # Otherwise, return the start position and the end position.
        start_pos = np.random.randint(low=0, high=length - interval_length + 1)
        return start_pos, start_pos + interval_length
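A usage sketch for the model above; StubVocab is a hypothetical stand-in for src.utils.Dictionary that exposes only the index attributes emit() reads, and the import path is assumed:

import numpy as np

from src.language_model import LooseMaskedLanguageModel  # assumed import path


class StubVocab:
    """Hypothetical vocabulary exposing only the indices used by emit()."""
    bos_index = 0
    eos_index = 1
    padding_index = 2
    mask_index = 3


np.random.seed(1)
lm = LooseMaskedLanguageModel(mask_ratio=0.5)
sentence = np.array([10, 11, 12, 13, 14, 15, StubVocab.eos_index], dtype=np.int32)
opt = lm.emit(sentence=sentence, vocabulary=StubVocab())
# opt["encoder_input"]  : source with the chosen fragment replaced by mask_index
# opt["decoder_input"]  : right-shifted fragment, left-padded with padding_index
# opt["decoder_output"] : the original tokens the decoder must recover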