# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dataset preprocess"""

import argparse

from src.utils.data_preprocess import lambada_dataset_preprocess
from src.utils.data_preprocess import cbt_dataset_preprocess
from src.utils.data_preprocess import wikitext_dataset_preprocess
from src.utils.data_preprocess import ptb_dataset_preprocess
from src.utils.data_preprocess import onebw_dataset_preprocess
from src.utils.data_preprocess import coqa_dataset_preprocess
from src.utils.data_preprocess import wmt14_en_fr_preprocess


def main():
    """Parse command-line arguments and run the preprocessor for the selected task."""
    parser = argparse.ArgumentParser(description="All Task dataset preprocessing")
    parser.add_argument("--task", type=str, default="translation",
                        help="The GPT-2 downstream task, one of [LanguageModeling, CBT, Translation, Lambada, "
                             "Summarization, ReadingComprehension].")
    parser.add_argument("--input_file", type=str, default="",
                        help="The raw dataset path.")
    parser.add_argument("--dataset", type=str, default="onebw",
                        help="The name of the dataset to be processed, only for the LanguageModeling task.")
    parser.add_argument("--output_file", type=str, default="",
                        help="The output dataset path after preprocessing.")
    parser.add_argument("--condition", type=str, default="test",
                        help="Process the train or test split, one of [train, test], only for the 1BW and "
                             "CNN & DailyMail datasets.")
    args_opt = parser.parse_args()

    task = args_opt.task
    condition = args_opt.condition
    dataset = args_opt.dataset
    input_file = args_opt.input_file
    output_file = args_opt.output_file

    if task.lower() == "languagemodeling":
        print("Start processing Language Modeling dataset ...")
        if dataset.lower() == "wikitext2" or dataset.lower() == "wikitext103":
            wikitext_dataset_preprocess(input_file=input_file, output_file=output_file)
        elif dataset.lower() == "ptb":
            ptb_dataset_preprocess(input_file=input_file, output_file=output_file)
        elif dataset.lower() == "onebw":
            onebw_dataset_preprocess(condition, input_file=input_file, output_file=output_file)
        else:
            raise ValueError("Only the wikitext2, wikitext103, ptb and onebw datasets are supported "
                             "for the LanguageModeling task.")

    elif task.lower() == "lambada":
        print("Start processing Lambada dataset ...")
        lambada_dataset_preprocess(input_file=input_file, output_file=output_file)

    elif task.lower() == "cbt":
        print("Start processing CBT dataset ...")
        cbt_dataset_preprocess(input_file=input_file, output_file=output_file)

    elif task.lower() == "readingcomprehension":
        print("Start processing ReadingComprehension dataset ...")
        coqa_dataset_preprocess(input_file=input_file, output_file=output_file)

    elif task.lower() == "summarization":
        # No preprocessing call is made for the Summarization task here.
        print("Start processing Summarization dataset ...")

    elif task.lower() == "translation":
        print("Start processing Translation dataset ...")
        wmt14_en_fr_preprocess(input_file=input_file, output_file=output_file)

    else:
        raise ValueError("Only the LanguageModeling, CBT, Translation, Lambada, "
                         "Summarization and ReadingComprehension tasks are supported.")


if __name__ == "__main__":
    main()