!8430 [MD] transform dataset to mindrecord in preprocess of gpt
From: @liyong126 Reviewed-by: Signed-off-by:pull/8430/MERGE
commit
848fb9b554
@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the scipt as: "
|
||||
echo "bash script/pre_process.sh \"INPUT_GLOB\" DATASET_TYPE OUTPUT_FILE"
|
||||
echo "for example: bash script/pre_process.sh \"dataset/*.output\" openwebtext ./output/openwebtext.mindrecord"
|
||||
echo "=============================================================================================================="
|
||||
|
||||
INPUT_GLOB=$1
|
||||
DATASET_TYPE=$2
|
||||
OUTPUT_FILE=$3
|
||||
|
||||
python ./src/pre_process.py \
|
||||
--input_glob=$INPUT_GLOB \
|
||||
--dataset_type=$DATASET_TYPE \
|
||||
--output_file=$OUTPUT_FILE \
|
||||
--file_partition=4
|
@ -0,0 +1,216 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
transform wikitext-2, wikitext-103, lambada, openwebtext dataset to mindrecord.
|
||||
"""
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from multiprocessing import Pool, current_process
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from transformers import GPT2Tokenizer
|
||||
except ModuleNotFoundError:
|
||||
print("module 'transformers' not installed.")
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
|
||||
EOT = 50256 # id of endoftext
|
||||
SEQ_LEN = 1025 # the length of sample
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
|
||||
|
||||
def chunks(lst, n):
|
||||
""" yield n sized chunks from list"""
|
||||
for i in range(0, len(lst), n):
|
||||
yield lst[i:i+n]
|
||||
|
||||
|
||||
def package_file(it, n):
|
||||
""" package multiple files"""
|
||||
stop = False
|
||||
while not stop:
|
||||
batch = []
|
||||
for _ in range(n):
|
||||
try:
|
||||
batch.append(next(it))
|
||||
except StopIteration:
|
||||
stop = True
|
||||
if not batch:
|
||||
break
|
||||
yield batch
|
||||
|
||||
|
||||
def clean_wikitext(string):
|
||||
""" cleaning wikitext dataset"""
|
||||
# contractions
|
||||
string = string.replace("s '", "s'")
|
||||
string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
|
||||
# number separators
|
||||
string = string.replace(" @-@ ", "-")
|
||||
string = string.replace(" @,@ ", ",")
|
||||
string = string.replace(" @.@ ", ".")
|
||||
# punctuation
|
||||
string = string.replace(" : ", ": ")
|
||||
string = string.replace(" ; ", "; ")
|
||||
string = string.replace(" . ", ". ")
|
||||
string = string.replace(" ! ", "! ")
|
||||
string = string.replace(" ? ", "? ")
|
||||
string = string.replace(" , ", ", ")
|
||||
# double brackets
|
||||
string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string)
|
||||
string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string)
|
||||
string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string)
|
||||
string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string)
|
||||
string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string)
|
||||
# miscellaneous
|
||||
string = string.replace("= = = =", "====")
|
||||
string = string.replace("= = =", "===")
|
||||
string = string.replace("= =", "==")
|
||||
string = string.replace(" "+chr(176)+" ", chr(176))
|
||||
string = string.replace(" \n", "\n")
|
||||
string = string.replace("\n ", "\n")
|
||||
string = string.replace(" N ", " 1 ")
|
||||
string = string.replace(" 's", "'s")
|
||||
return string
|
||||
|
||||
|
||||
def tokenize_openwebtext(iterator):
|
||||
""" tokenize openwebtext dataset"""
|
||||
for file_path in iterator:
|
||||
if os.path.getsize(file_path) == 0:
|
||||
continue
|
||||
content = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for para in f.read().split("\n\n"):
|
||||
if para:
|
||||
tokenized_text = tokenizer.tokenize(para)
|
||||
content += tokenizer.convert_tokens_to_ids(tokenized_text) + [
|
||||
EOT]
|
||||
for chunk in chunks(content, SEQ_LEN):
|
||||
sample = {}
|
||||
if len(chunk) == SEQ_LEN:
|
||||
sample['input_ids'] = np.array(chunk, dtype=np.int32)
|
||||
yield sample
|
||||
|
||||
|
||||
def tokenize_wiki(file_path):
|
||||
"""tokenize wikitext-2/wikitext-103 dataset"""
|
||||
content = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for para in clean_wikitext(f.read()).split("\n\n"):
|
||||
if para and para.strip().startswith('=') is False:
|
||||
tokenized_text = tokenizer.tokenize(para)
|
||||
content += tokenizer.convert_tokens_to_ids(tokenized_text) + [
|
||||
EOT]
|
||||
for chunk in chunks(content, SEQ_LEN):
|
||||
sample = {}
|
||||
if len(chunk) == SEQ_LEN:
|
||||
sample['input_ids'] = np.array(chunk, dtype=np.int32)
|
||||
yield sample
|
||||
|
||||
|
||||
def tokenize_lambada(file_path):
|
||||
"""tokenize lambada dataset"""
|
||||
content = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f.readlines():
|
||||
para = json.loads(line)['text'].replace(
|
||||
"“", '""').replace("”", '"').strip().strip(".")
|
||||
tokenized_text = tokenizer.tokenize(para)
|
||||
content += tokenizer.convert_tokens_to_ids(tokenized_text) + [EOT]
|
||||
for chunk in chunks(content, SEQ_LEN):
|
||||
sample = {}
|
||||
if len(chunk) == SEQ_LEN:
|
||||
sample['input_ids'] = np.array(chunk, dtype=np.int32)
|
||||
yield sample
|
||||
|
||||
|
||||
def task_unit(iterator, parallel_writer=True):
|
||||
"""task for each process"""
|
||||
p = current_process()
|
||||
index = p.pid if p.pid else 0
|
||||
|
||||
item_iter = tokenize_openwebtext(iterator)
|
||||
batch_size = 1024 # size of write batch
|
||||
count = 0
|
||||
while True:
|
||||
data_batch = []
|
||||
try:
|
||||
for _ in range(batch_size):
|
||||
data_batch.append(next(item_iter))
|
||||
count += 1
|
||||
writer.write_raw_data(data_batch, parallel_writer=parallel_writer)
|
||||
print("Process {} transformed {} records.".format(
|
||||
index, count))
|
||||
except StopIteration:
|
||||
if data_batch:
|
||||
writer.write_raw_data(data_batch,
|
||||
parallel_writer=parallel_writer)
|
||||
print("Process {} transformed {} records.".format(
|
||||
index, count))
|
||||
break
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--dataset_type', type=str, default='openwebtext')
|
||||
parser.add_argument('--input_glob', type=str, default='*.txt')
|
||||
parser.add_argument('--output_file', type=str,
|
||||
default='./output/openweb_mindrecord')
|
||||
parser.add_argument('--file_partition', type=int, default=1)
|
||||
parser.add_argument('--file_batch_size', type=int, default=1024)
|
||||
parser.add_argument('--num_process', type=int, default=64)
|
||||
|
||||
args = parser.parse_args()
|
||||
###
|
||||
out_dir, out_file = os.path.split(os.path.abspath(args.output_file))
|
||||
if not os.path.exists(out_dir):
|
||||
os.mkdir(out_dir)
|
||||
schema = {"input_ids": {"type": "int32", "shape": [-1]},}
|
||||
writer = FileWriter(file_name=args.output_file,
|
||||
shard_num=args.file_partition)
|
||||
writer.add_schema(schema, args.dataset_type)
|
||||
writer.open_and_set_header()
|
||||
###
|
||||
transforms_count = 0
|
||||
if args.dataset_type == 'wiki':
|
||||
for x in tokenize_wiki(args.input_glob):
|
||||
transforms_count += 1
|
||||
writer.write_raw_data([x])
|
||||
print("Transformed {} records.".format(transforms_count))
|
||||
elif args.dataset_type == 'lambada':
|
||||
for x in tokenize_lambada(args.input_glob):
|
||||
transforms_count += 1
|
||||
writer.write_raw_data([x])
|
||||
print("Transformed {} records.".format(transforms_count))
|
||||
elif args.dataset_type == 'openwebtext':
|
||||
file_iter = glob.iglob(args.input_glob)
|
||||
with Pool(processes=args.num_process) as pool:
|
||||
pool.map(task_unit, package_file(file_iter, args.file_batch_size))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Not support dataset type: {}".format(args.dataset_type))
|
||||
|
||||
writer.commit()
|
||||
out_file = args.output_file
|
||||
if args.file_partition > 1:
|
||||
out_file += '0'
|
||||
print("Transform finished, output files refer: {}".format(out_file))
|
Loading…
Reference in new issue