@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import argparse
+import ast
 import collections
 import logging
 import numpy as np
@@ -51,23 +52,23 @@ class SampleInstance():
         return self.__str__()
 
 
-def write_instance_to_file(writer, instance, tokenizer, max_seq_length, bucket):
-    """Create files from `SampleInstance`s."""
-    def _find_bucket_length(num):
+def get_instance_features(instance, tokenizer, max_seq_length, bucket):
+    """Get features from `SampleInstance`s."""
+    def _find_bucket_length(source_tokens, target_tokens):
+        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
+        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
+        num = max(len(source_ids), len(target_ids))
         assert num <= bucket[-1]
         for index in range(1, len(bucket)):
             if bucket[index - 1] < num <= bucket[index]:
                 return bucket[index]
         return bucket[0]
 
-    def _convert_ids_and_mask(input_tokens):
+    def _convert_ids_and_mask(input_tokens, seq_max_bucket_length):
        input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
         input_mask = [1] * len(input_ids)
         assert len(input_ids) <= max_seq_length
-        seq_max_bucket_length = _find_bucket_length(len(input_ids))
         while len(input_ids) < seq_max_bucket_length:
             input_ids.append(0)
             input_mask.append(0)
@@ -77,10 +78,11 @@ def write_instance_to_file(writer, instance, tokenizer, max_seq_length, bucket):
         return input_ids, input_mask
 
-    source_sos_ids, source_sos_mask = _convert_ids_and_mask(instance.source_sos_tokens)
-    source_eos_ids, source_eos_mask = _convert_ids_and_mask(instance.source_eos_tokens)
-    target_sos_ids, target_sos_mask = _convert_ids_and_mask(instance.target_sos_tokens)
-    target_eos_ids, target_eos_mask = _convert_ids_and_mask(instance.target_eos_tokens)
+    seq_max_bucket_length = _find_bucket_length(instance.source_sos_tokens, instance.target_sos_tokens)
+    source_sos_ids, source_sos_mask = _convert_ids_and_mask(instance.source_sos_tokens, seq_max_bucket_length)
+    source_eos_ids, source_eos_mask = _convert_ids_and_mask(instance.source_eos_tokens, seq_max_bucket_length)
+    target_sos_ids, target_sos_mask = _convert_ids_and_mask(instance.target_sos_tokens, seq_max_bucket_length)
+    target_eos_ids, target_eos_mask = _convert_ids_and_mask(instance.target_eos_tokens, seq_max_bucket_length)
 
     features = collections.OrderedDict()
     features["source_sos_ids"] = np.asarray(source_sos_ids)
@@ -92,8 +94,7 @@ def write_instance_to_file(writer, instance, tokenizer, max_seq_length, bucket):
     features["target_eos_ids"] = np.asarray(target_eos_ids)
     features["target_eos_mask"] = np.asarray(target_eos_mask)
-    writer.write_raw_data([features])
-    return features
+    return features, seq_max_bucket_length
 
 
 def create_training_instance(source_words, target_words, max_seq_length, clip_to_max_len):
     """Creates `SampleInstance`s for a single sentence pair."""
@@ -131,7 +132,8 @@ def main():
     parser.add_argument("--clip_to_max_len", type=bool, default=False,
                         help='clip sequences to maximum sequence length.')
     parser.add_argument("--max_seq_length", type=int, default=128, help='Maximum sequence length.')
-    parser.add_argument("--bucket", type=list, default=[16, 32, 48, 64, 128], help='bucket sequence length')
+    parser.add_argument("--bucket", type=ast.literal_eval, default=[16, 32, 48, 64, 128],
+                        help='bucket sequence length')
 
     args = parser.parse_args()
 
@@ -141,29 +143,21 @@ def main():
     for input_pattern in args.input_file.split(","):
         input_files.append(input_pattern)
 
-    logging.info("*** Reading from input files ***")
+    logging.info("*** Read from input files ***")
     for input_file in input_files:
         logging.info(" %s", input_file)
 
     output_file = args.output_file
-    logging.info("*** Writing to output files ***")
+    logging.info("*** Write to output files ***")
     logging.info(" %s", output_file)
 
-    writer = FileWriter(output_file, args.num_splits)
-    data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
-                   "source_sos_mask": {"type": "int64", "shape": [-1]},
-                   "source_eos_ids": {"type": "int64", "shape": [-1]},
-                   "source_eos_mask": {"type": "int64", "shape": [-1]},
-                   "target_sos_ids": {"type": "int64", "shape": [-1]},
-                   "target_sos_mask": {"type": "int64", "shape": [-1]},
-                   "target_eos_ids": {"type": "int64", "shape": [-1]},
-                   "target_eos_mask": {"type": "int64", "shape": [-1]}
-                   }
-    writer.add_schema(data_schema, "tranformer hisi")
-
     total_written = 0
     total_read = 0
 
+    feature_dict = {}
+    for i in args.bucket:
+        feature_dict[i] = []
+
     for input_file in input_files:
         logging.info("*** Reading from %s ***", input_file)
         with open(input_file, "r") as reader:
@@ -174,7 +168,7 @@ def main():
 
                 total_read += 1
                 if total_read % 100000 == 0:
-                    logging.info("%d ...", total_read)
+                    logging.info("Read %d ...", total_read)
 
                 source_line, target_line = line.strip().split("\t")
                 source_tokens = tokenizer.tokenize(source_line)
@@ -189,10 +183,13 @@ def main():
                 if instance is None:
                     continue
 
-                features = write_instance_to_file(writer, instance, tokenizer, args.max_seq_length, args.bucket)
-                total_written += 1
+                features, seq_max_bucket_length = get_instance_features(instance, tokenizer, args.max_seq_length,
+                                                                        args.bucket)
+                for key in feature_dict:
+                    if key == seq_max_bucket_length:
+                        feature_dict[key].append(features)
 
-                if total_written <= 20:
+                if total_read <= 10:
                     logging.info("*** Example ***")
                     logging.info("source tokens: %s", " ".join(
                         [tokenization.convert_to_printable(x) for x in instance.source_eos_tokens]))
@@ -203,9 +200,33 @@ def main():
                         feature = features[feature_name]
                         logging.info("%s: %s", feature_name, feature)
 
-    writer.commit()
+    for i in args.bucket:
+        if args.num_splits == 1:
+            output_file_name = output_file
+        else:
+            output_file_name = output_file + '_' + str(i) + '_'
+        writer = FileWriter(output_file_name, args.num_splits)
+        data_schema = {"source_sos_ids": {"type": "int64", "shape": [-1]},
+                       "source_sos_mask": {"type": "int64", "shape": [-1]},
+                       "source_eos_ids": {"type": "int64", "shape": [-1]},
+                       "source_eos_mask": {"type": "int64", "shape": [-1]},
+                       "target_sos_ids": {"type": "int64", "shape": [-1]},
+                       "target_sos_mask": {"type": "int64", "shape": [-1]},
+                       "target_eos_ids": {"type": "int64", "shape": [-1]},
+                       "target_eos_mask": {"type": "int64", "shape": [-1]}
+                       }
+        writer.add_schema(data_schema, "tranformer")
+        features_ = feature_dict[i]
+        logging.info("Bucket length %d has %d samples, start writing...", i, len(features_))
+
+        for item in features_:
+            writer.write_raw_data([item])
+            total_written += 1
+        writer.commit()
+
     logging.info("Wrote %d total instances", total_written)
 
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     main()
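
Note on the new --bucket handling: argparse applies the type callable to the raw command-line string, so the old type=list would turn "[16, 32, 48, 64, 128]" into a list of single characters, while type=ast.literal_eval parses it into a list of integers. The short standalone sketch below is not part of the patch; its find_bucket_length helper simply mirrors the patched _find_bucket_length, and the sample lengths are made up for illustration. It shows the parsing difference and how a sequence length is mapped onto a bucket boundary:

    import ast

    def find_bucket_length(num, bucket):
        # Same rule as the patched _find_bucket_length: pick the smallest
        # bucket boundary that can hold a sequence of `num` ids.
        assert num <= bucket[-1]
        for index in range(1, len(bucket)):
            if bucket[index - 1] < num <= bucket[index]:
                return bucket[index]
        return bucket[0]

    raw = "[16, 32, 48, 64, 128]"
    print(list(raw)[:5])          # ['[', '1', '6', ',', ' ']  (what type=list would produce)
    print(ast.literal_eval(raw))  # [16, 32, 48, 64, 128]

    buckets = ast.literal_eval(raw)
    for length in (7, 16, 17, 100):  # hypothetical token counts
        print(length, "->", find_bucket_length(length, buckets))

With that in place, main() groups each sample's features under its bucket length in feature_dict and, in the last hunk, writes every bucket with its own FileWriter (file name suffixed with '_<bucket>_' when num_splits is not 1) before committing.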