# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate Gigaword dataset."""
import os
import argparse

from src.dataset import BiLingualDataLoader
from src.language_model import NoiseChannelLanguageModel
from src.utils import Dictionary

parser = argparse.ArgumentParser(description='Create Gigaword fine-tune Dataset.')
parser.add_argument("--train_src", type=str, default="", required=False,
                    help="train dataset source file path.")
parser.add_argument("--train_ref", type=str, default="", required=False,
                    help="train dataset reference file path.")
parser.add_argument("--test_src", type=str, default="", required=False,
                    help="test dataset source file path.")
parser.add_argument("--test_ref", type=str, default="", required=False,
                    help="test dataset reference file path.")
parser.add_argument("--noise_prob", type=float, default=0., required=False,
                    help="add noise prob.")
parser.add_argument("--existed_vocab", type=str, default="", required=False,
                    help="existed vocab path.")
parser.add_argument("--max_len", type=int, default=64, required=False,
                    help="max length of sentences.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="dataset output path.")
parser.add_argument("--format", type=str, default="tfrecord", required=False,
                    help="dataset format.")

if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    # Load the pre-built vocabulary; the same dictionary serves as both the
    # source and target dictionary since both sides are English.
    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)

    if args.train_src and args.train_ref:
        # Training split: noise may be injected according to --noise_prob.
        train = BiLingualDataLoader(
            src_filepath=args.train_src,
            tgt_filepath=args.train_ref,
            src_dict=vocab,
            tgt_dict=vocab,
            src_lang="en",
            tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=args.noise_prob),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            train.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.tfrecord")
            )
        else:
            train.write_to_mindrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.mindrecord")
            )

    if args.test_src and args.test_ref:
        # Test split: noise is always disabled so the evaluation data stays clean.
        test = BiLingualDataLoader(
            src_filepath=args.test_src,
            tgt_filepath=args.test_ref,
            src_dict=vocab,
            tgt_dict=vocab,
            src_lang="en",
            tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=0.),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            test.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_test_dataset.tfrecord")
            )
        else:
            test.write_to_mindrecord(
                path=os.path.join(args.output_folder, "gigaword_test_dataset.mindrecord")
            )

    print(f" | Vocabulary size: {vocab.size}.")
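
# Example invocation (a sketch only; the script name and every path below are
# hypothetical placeholders, not files shipped with this repository):
#
#   python gigaword.py \
#       --existed_vocab /path/to/vocab_en.dict.bin \
#       --train_src /path/to/train.article.txt \
#       --train_ref /path/to/train.title.txt \
#       --test_src /path/to/test.article.txt \
#       --test_ref /path/to/test.title.txt \
#       --noise_prob 0.1 \
#       --max_len 64 \
#       --output_folder /path/to/output \
#       --format tfrecord
#
# A split is only written when both its --*_src and --*_ref paths are given,
# so the script can build the train set, the test set, or both in one run.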