!8430 [MD] transform dataset to mindrecord in preprocess of gpt

From: @liyong126
Reviewed-by: 
Signed-off-by:
pull/8430/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 848fb9b554

@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash script/pre_process.sh \"INPUT_GLOB\" DATASET_TYPE OUTPUT_FILE"
echo "for example: bash script/pre_process.sh \"dataset/*.output\" openwebtext ./output/openwebtext.mindrecord"
echo "=============================================================================================================="
INPUT_GLOB=$1
DATASET_TYPE=$2
OUTPUT_FILE=$3
python ./src/pre_process.py \
--input_glob=$INPUT_GLOB \
--dataset_type=$DATASET_TYPE \
--output_file=$OUTPUT_FILE \
--file_partition=4

@ -0,0 +1,216 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
transform wikitext-2, wikitext-103, lambada, openwebtext dataset to mindrecord.
"""
import argparse
import glob
import json
import os
import re
from multiprocessing import Pool, current_process
import numpy as np
try:
from transformers import GPT2Tokenizer
except ModuleNotFoundError:
print("module 'transformers' not installed.")
from mindspore.mindrecord import FileWriter
EOT = 50256 # id of endoftext
SEQ_LEN = 1025 # the length of sample
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
def chunks(lst, n):
""" yield n sized chunks from list"""
for i in range(0, len(lst), n):
yield lst[i:i+n]
def package_file(it, n):
""" package multiple files"""
stop = False
while not stop:
batch = []
for _ in range(n):
try:
batch.append(next(it))
except StopIteration:
stop = True
if not batch:
break
yield batch
def clean_wikitext(string):
""" cleaning wikitext dataset"""
# contractions
string = string.replace("s '", "s'")
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
# number separators
string = string.replace(" @-@ ", "-")
string = string.replace(" @,@ ", ",")
string = string.replace(" @.@ ", ".")
# punctuation
string = string.replace(" : ", ": ")
string = string.replace(" ; ", "; ")
string = string.replace(" . ", ". ")
string = string.replace(" ! ", "! ")
string = string.replace(" ? ", "? ")
string = string.replace(" , ", ", ")
# double brackets
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
# miscellaneous
string = string.replace("= = = =", "====")
string = string.replace("= = =", "===")
string = string.replace("= =", "==")
string = string.replace(" "+chr(176)+" ", chr(176))
string = string.replace(" \n", "\n")
string = string.replace("\n ", "\n")
string = string.replace(" N ", " 1 ")
string = string.replace(" 's", "'s")
return string
def tokenize_openwebtext(iterator):
""" tokenize openwebtext dataset"""
for file_path in iterator:
if os.path.getsize(file_path) == 0:
continue
content = []
with open(file_path, 'r', encoding='utf-8') as f:
for para in f.read().split("\n\n"):
if para:
tokenized_text = tokenizer.tokenize(para)
content += tokenizer.convert_tokens_to_ids(tokenized_text) + [
EOT]
for chunk in chunks(content, SEQ_LEN):
sample = {}
if len(chunk) == SEQ_LEN:
sample['input_ids'] = np.array(chunk, dtype=np.int32)
yield sample
def tokenize_wiki(file_path):
"""tokenize wikitext-2/wikitext-103 dataset"""
content = []
with open(file_path, 'r', encoding='utf-8') as f:
for para in clean_wikitext(f.read()).split("\n\n"):
if para and para.strip().startswith('=') is False:
tokenized_text = tokenizer.tokenize(para)
content += tokenizer.convert_tokens_to_ids(tokenized_text) + [
EOT]
for chunk in chunks(content, SEQ_LEN):
sample = {}
if len(chunk) == SEQ_LEN:
sample['input_ids'] = np.array(chunk, dtype=np.int32)
yield sample
def tokenize_lambada(file_path):
"""tokenize lambada dataset"""
content = []
with open(file_path, 'r', encoding='utf-8') as f:
for line in f.readlines():
para = json.loads(line)['text'].replace(
"", '""').replace("", '"').strip().strip(".")
tokenized_text = tokenizer.tokenize(para)
content += tokenizer.convert_tokens_to_ids(tokenized_text) + [EOT]
for chunk in chunks(content, SEQ_LEN):
sample = {}
if len(chunk) == SEQ_LEN:
sample['input_ids'] = np.array(chunk, dtype=np.int32)
yield sample
def task_unit(iterator, parallel_writer=True):
"""task for each process"""
p = current_process()
index = p.pid if p.pid else 0
item_iter = tokenize_openwebtext(iterator)
batch_size = 1024 # size of write batch
count = 0
while True:
data_batch = []
try:
for _ in range(batch_size):
data_batch.append(next(item_iter))
count += 1
writer.write_raw_data(data_batch, parallel_writer=parallel_writer)
print("Process {} transformed {} records.".format(
index, count))
except StopIteration:
if data_batch:
writer.write_raw_data(data_batch,
parallel_writer=parallel_writer)
print("Process {} transformed {} records.".format(
index, count))
break
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_type', type=str, default='openwebtext')
parser.add_argument('--input_glob', type=str, default='*.txt')
parser.add_argument('--output_file', type=str,
default='./output/openweb_mindrecord')
parser.add_argument('--file_partition', type=int, default=1)
parser.add_argument('--file_batch_size', type=int, default=1024)
parser.add_argument('--num_process', type=int, default=64)
args = parser.parse_args()
###
out_dir, out_file = os.path.split(os.path.abspath(args.output_file))
if not os.path.exists(out_dir):
os.mkdir(out_dir)
schema = {"input_ids": {"type": "int32", "shape": [-1]},}
writer = FileWriter(file_name=args.output_file,
shard_num=args.file_partition)
writer.add_schema(schema, args.dataset_type)
writer.open_and_set_header()
###
transforms_count = 0
if args.dataset_type == 'wiki':
for x in tokenize_wiki(args.input_glob):
transforms_count += 1
writer.write_raw_data([x])
print("Transformed {} records.".format(transforms_count))
elif args.dataset_type == 'lambada':
for x in tokenize_lambada(args.input_glob):
transforms_count += 1
writer.write_raw_data([x])
print("Transformed {} records.".format(transforms_count))
elif args.dataset_type == 'openwebtext':
file_iter = glob.iglob(args.input_glob)
with Pool(processes=args.num_process) as pool:
pool.map(task_unit, package_file(file_iter, args.file_batch_size))
else:
raise ValueError(
"Not support dataset type: {}".format(args.dataset_type))
writer.commit()
out_file = args.output_file
if args.file_partition > 1:
out_file += '0'
print("Transform finished, output files refer: {}".format(out_file))
Loading…
Cancel
Save