commit ac5371b38f
File diff suppressed because it is too large
@@ -0,0 +1,141 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
CNN & DailyMail train dataset sampler
"""

import os
import sys
import shutil
import argparse
from random import random

from src.utils.tokenization import Tokenizer


def replace_split_word(read_path, output_path, tldr_str="TL;DR:", original_split='\t'):
    """
    replace the original split word between article and summary with the TL;DR hint string
    """
    with open(read_path, "r") as r, open(output_path, "a") as w:
        line = r.readline()
        while line:
            article = line[:line.find(original_split)] + ' ' + tldr_str + ' '
            ref = line[line.rfind(original_split) + 1:]
            w.write(article + ref)
            line = r.readline()


def sample(read_path, out_path, threshold=1.0, max_items=0xFFFFFFF):
    """
    randomly keep each line with probability `threshold`, writing at most `max_items` lines
    """
    cnt = 0
    total_cnt = 0
    with open(read_path, "r") as r, open(out_path, "a") as w:
        line = r.readline()
        while line:
            total_cnt += 1
            if cnt >= max_items:
                break
            if random() > threshold:
                line = r.readline()
                continue
            w.write(line)
            if (cnt + 1) % 3000 == 0:
                print("Now Processed Samples: {}, total: {}".format(cnt, total_cnt))
            cnt += 1
            line = r.readline()


def clip_article(input_path, out_path, hint, max_length):
    """
    clip the article when the tokenized sample (article + summary) exceeds max_length
    """
    tokenizer = Tokenizer()
    cnt = 0
    with open(input_path, "r") as r, open(out_path, "a+") as a:
        line = r.readline()
        while line:
            pos = line.rfind(hint)
            article = line[:pos]
            summary = line[pos:]
            if len(tokenizer.encode(line)) > max_length:
                l_article = tokenizer.encode(article)[:max_length - len(tokenizer.encode(summary))]
                article = tokenizer.decode(l_article) + " "
                if cnt % 1000 == 0:
                    print(article + summary)
                    print("==============================")
                cnt += 1
            a.write(article + summary)
            line = r.readline()


def sampler_dataset():
    """
    run CNN & DailyMail train dataset sampler
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, default="",
                        help="input file path")
    parser.add_argument("--output_path", type=str, default="",
                        help="out file path")
    parser.add_argument("--replace_hint", type=str, default="true")
    parser.add_argument("--sample", type=str, default="true",
                        help="do sample? true or false")
    parser.add_argument("--max_length", type=int, default=1022,
                        help="max seq_length of input_raw_dataset")
    parser.add_argument("--prob", type=float, default=0.25,
                        help="sample rate")
    parser.add_argument("--max_items", type=int, default=10000,
                        help="max number of documents")
    parser.add_argument("--hint", type=str, default="TL;DR:",
                        help="hint text")
    args = parser.parse_args()

    # temp_files, one for storing inputs in every stage, the other for storing middle results.
    temp_file_input = sys.path[0] + '/temp_file1_by_sampler_py.txt'
    temp_file_proc = sys.path[0] + '/temp_file2_by_sampler_py.txt'

    read_path = args.input_path
    output_path = args.output_path
    prob = args.prob
    max_items = args.max_items
    hint = args.hint
    max_length = args.max_length
    split_str = '\t'

    shutil.copyfile(read_path, temp_file_input)
    clip_article(temp_file_input, temp_file_proc, hint=split_str, max_length=max_length)
    shutil.copyfile(temp_file_proc, temp_file_input)
    os.remove(temp_file_proc)

    if args.replace_hint.lower() == "true":
        replace_split_word(temp_file_input, temp_file_proc, hint, split_str)
        shutil.copyfile(temp_file_proc, temp_file_input)
        os.remove(temp_file_proc)

    if args.sample.lower() == "true":
        sample(temp_file_input, temp_file_proc, prob, max_items)
        shutil.copyfile(temp_file_proc, temp_file_input)
        os.remove(temp_file_proc)

    shutil.copyfile(temp_file_input, output_path)
    os.remove(temp_file_input)


if __name__ == "__main__":
    sampler_dataset()
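As a quick, hand-made illustration (not part of the commit): given one raw line in the "article<TAB>summary" layout produced by the download script later in this diff, replace_split_word rewrites it as sketched below.

# Illustrative sketch only: invented text, mirrors replace_split_word() on a single line.
line = "Some news article text.\tA short summary.\n"
tldr_str = "TL;DR:"
article = line[:line.find('\t')] + ' ' + tldr_str + ' '
ref = line[line.rfind('\t') + 1:]
print(article + ref)  # -> Some news article text. TL;DR: A short summary.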
@@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Read weight using tensorflow:
read the parameters of the gpt-2 pretrained model from a tensorflow checkpoint
and save them into npy files for mindspore to load.

*this script is based on the gpt-2 model downloaded from openai.*
"""
import argparse
import tensorflow as tf
import numpy as np

from .trans_dict import trans_dict_tf


def read_weight(ckpt_path):
    """
    read weight
    Args:
        ckpt_path: the path of the tensorflow checkpoint
    """
    # model path and model name
    init_vars = tf.train.list_variables(ckpt_path)
    # load the model parameters into vars
    save_param_num = 0

    for name, _ in init_vars:
        array = tf.train.load_variable(ckpt_path, name)
        # strip the leading 'model/' prefix and replace '/' with '.' so the name matches trans_dict_tf
        name = name[6:].replace(r"/", ".")
        if name not in trans_dict_tf.keys():
            print(name + " is not in this model")
        else:
            # save the parameters as '.npy' files
            np.save(trans_dict_tf[name] + ".npy", array)
            save_param_num = save_param_num + 1

    print("finished!")
    print("saved {num} parameters.".format(num=save_param_num))


def main():
    parser = argparse.ArgumentParser(description="Read GPT-2 model checkpoint weight")
    parser.add_argument("--ckpt_file_path", type=str, default="",
                        help="The tensorflow GPT-2 model checkpoint file path")
    args_opt = parser.parse_args()
    ckpt_path = args_opt.ckpt_file_path
    read_weight(ckpt_path=ckpt_path)


if __name__ == "__main__":
    main()
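For reference, a hypothetical TensorFlow variable name (invented here, not read from a real checkpoint) would pass through the renaming step above like this before the trans_dict_tf lookup.

# Hypothetical name; real names come from tf.train.list_variables(ckpt_path).
name = "model/h0/attn/c_attn/w"
name = name[6:].replace(r"/", ".")  # drop the leading "model/" prefix, use "." as the separator
print(name)  # -> h0.attn.c_attn.w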
@@ -0,0 +1,60 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Save weight using mindspore: load the parameters of the gpt-2 model from npy files.
The npy files should be in the same directory as this script; otherwise, change the paths used below.
"""
import os
import argparse
import numpy as np

from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint

from .trans_dict import trans_dict_tf


def trans_model_parameter(ckpt_name):
    """
    transform model parameters
    Args:
        ckpt_name (str): the name of the transformed checkpoint.
    """
    # find all file names with the suffix '.npy' in the current path
    file_names = [name for name in os.listdir() if name.endswith(".npy")]
    new_params_list = []
    for file_name in file_names:
        var_name = file_name[:-4]
        param_dict = {"name": var_name, "data": Tensor(np.load(file_name))}
        if var_name in trans_dict_tf.values():
            new_params_list.append(param_dict)
            print(var_name + " has been saved")

    # load the parameters from the npy files and save them as a mindspore checkpoint
    save_checkpoint(new_params_list, ckpt_name)
    print("Finished: the parameters have been saved into a mindspore checkpoint.")


def main():
    parser = argparse.ArgumentParser(description="Save GPT-2 model weight into a MindSpore checkpoint")
    parser.add_argument("--output_file_name", type=str, default="",
                        help="The name of the output checkpoint file")
    args_opt = parser.parse_args()
    ckpt_name = args_opt.output_file_name
    trans_model_parameter(ckpt_name=ckpt_name)


if __name__ == "__main__":
    main()
File diff suppressed because it is too large
@@ -0,0 +1,148 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord data for Children's Book Test task"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import numpy as np

from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None, num_choice=None):
    """A single sample instance for the CBT task."""
    text = text.replace(" \t ", "\t ")
    sentence = text.strip().split("\t")
    context_length = len(tokenizer.encode(sentence[0]))

    whole_sentence = sentence[0] + sentence[1]
    whole_sentence = whole_sentence.strip()
    assert whole_sentence != ""
    print(" | whole sentence: ", whole_sentence)
    ids = tokenizer.encode(whole_sentence)
    input_length = len(ids)
    pair_ids = None

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="RIGHT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)

    output["length"] = [context_length + 1] + [input_length + 1]

    gold_answer_id = int(sentence[2])
    assert gold_answer_id < 10
    output["mc_labels"] = gold_answer_id

    for name, value in output.items():
        print(name)
        print(value)
        print("==================================")

    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    assert len(input_ids) == len(input_mask)
    length = instance["length"]  # list
    mc_labels = instance["mc_labels"]

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["input_length"] = np.asarray(length)
    features["mc_labels"] = mc_labels

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, default="", help='Input raw text file. ')
    parser.add_argument("--output_file", type=str, required=True, default="", help='Output MindRecord file. ')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='The MindRecord file will be split into the number of partitions. ')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
    parser.add_argument("--num_choice", type=int, required=True, help='Number of choices. ')
    parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ')
    parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)
    num_choice = args.num_choice

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "input_length": {"type": "int64", "shape": [-1]},
                   "mc_labels": {"type": "int64"}
                   }
    writer.add_schema(data_schema, "cbt-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length, num_choice)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))

                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
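For reference, create_instance above assumes each input line carries three tab-separated fields: the context, the sentence that completes it, and the index of the gold candidate (0-9). A minimal sketch with invented text; the preprocessing that actually produces these lines is not shown in this diff.

# Invented example; only the "context <TAB> completed sentence <TAB> gold id" layout matters here.
line = "Once upon a time there was a cat . \t The cat sat on the mat . \t 3\n"
sentence = line.strip().split("\t")
context, completion, gold_answer_id = sentence[0], sentence[1], int(sentence[2])
assert 0 <= gold_answer_id < 10  # index into the 10 candidate answers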
@@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord data for LAMBADA task"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import numpy as np

from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for the LAMBADA task."""
    text = text.replace(" \t ", "\t ")
    sentence = text.strip().split("\t")
    context_length = len(tokenizer.encode(sentence[0]))

    whole_sentence = sentence[0] + sentence[1]
    whole_sentence = whole_sentence.strip()
    assert whole_sentence != ""
    print(" | whole sentence: ", whole_sentence)
    ids = tokenizer.encode(whole_sentence)
    input_length = len(ids)
    pair_ids = None

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="RIGHT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)

    # input_length = <bos> + text_length, not including <eos>
    output["length"] = [context_length + 1] + [input_length + 1]

    for k, v in output.items():
        print(k)
        print(v)
        print("==================================")

    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    assert len(input_ids) == len(input_mask)
    length = instance["length"]  # list

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["input_length"] = np.asarray(length)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='The MindRecord file will be split into the number of partitions. ')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
    parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ')
    parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "input_length": {"type": "int64", "shape": [-1]},
                   }
    writer.add_schema(data_schema, "lambada-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))

                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
@@ -0,0 +1,126 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord data for LM task"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import numpy as np

from mindspore.mindrecord import FileWriter
from src.utils.tokenization import Tokenizer


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for the LM task."""
    sentence = text.strip().split("\t")

    ids = tokenizer.encode(sentence[0])
    pair_ids = None
    if len(sentence) == 2:
        pair_ids = tokenizer.encode(sentence[1])

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         truncate_direction="LEFT",
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)
    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    label_ids = instance["input_ids"]
    assert len(input_ids) == len(label_ids)

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["label_ids"] = np.asarray(label_ids)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file. ')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file. ')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='The MindRecord file will be split into the number of partitions. ')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length. ')
    parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ')
    parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ')
    args = parser.parse_args()

    tokenizer = Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file)

    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "label_ids": {"type": "int64", "shape": [-1]}
                   }
    writer.add_schema(data_schema, "lm-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))

                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
@@ -0,0 +1,130 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""create mindrecord data for Summarization task"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import collections
import logging
import numpy as np

from mindspore.mindrecord import FileWriter
from src.utils import tokenization


def create_instance(tokenizer, text, max_length=None):
    """A single sample instance for the Summarization task."""
    sentence = text.strip().split("\t")
    ids = tokenizer.encode(sentence[0])
    pair_ids = None
    if len(sentence) == 2:
        pair_ids = tokenizer.encode(sentence[1])
    if len(sentence) >= 3:
        article = sentence[0]
        for i in range(1, len(sentence) - 1):
            article += sentence[i]
        ids = tokenizer.encode(article)
        pair_ids = tokenizer.encode(sentence[-1])

    output = tokenizer.prepare_for_model(ids=ids,
                                         pair_ids=pair_ids,
                                         add_special_tokens=True,
                                         max_length=max_length,
                                         padding=True,
                                         return_overflowing_tokens=False,
                                         return_attention_mask=True)
    return output


def write_instance_to_file(writer, instance):
    """write the instance to file"""
    input_ids = instance["input_ids"]
    input_mask = instance["attention_mask"]
    label_ids = instance["input_ids"]
    assert len(input_ids) == len(label_ids)

    features = collections.OrderedDict()
    features["input_ids"] = np.asarray(input_ids)
    features["input_mask"] = np.asarray(input_mask)
    features["label_ids"] = np.asarray(label_ids)

    writer.write_raw_data([features])
    return features


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help='Input raw text file.')
    parser.add_argument("--output_file", type=str, required=True, help='Output MindRecord file.')
    parser.add_argument("--num_splits", type=int, default=1,
                        help='The MindRecord file will be split into the number of partitions. ')
    parser.add_argument("--max_seq_length", type=int, required=True, help='Maximum sequence length.')
    parser.add_argument("--vocab_file", type=str, required=True, default='', help='url of gpt2-vocab.json ')
    parser.add_argument("--merge_file", type=str, required=True, default='', help='url of gpt2-merges.txt ')
    parser.add_argument("--mode", type=str, required=True, default='cnn_dailymail', help='mode of dataset creation')
    args = parser.parse_args()

    tokenizer = tokenization.Tokenizer(vocab_file=args.vocab_file, merge_file=args.merge_file, mode=args.mode)
    input_file = args.input_file
    logging.info("***** Reading from input files *****")
    logging.info("Input File: %s", input_file)

    output_file = args.output_file
    logging.info("***** Writing to output files *****")
    logging.info("Output File: %s", output_file)

    writer = FileWriter(output_file, args.num_splits)
    data_schema = {"input_ids": {"type": "int64", "shape": [-1]},
                   "input_mask": {"type": "int64", "shape": [-1]},
                   "label_ids": {"type": "int64", "shape": [-1]}
                   }
    writer.add_schema(data_schema, "wikitext2-schema")

    total_written = 0
    total_read = 0

    logging.info("***** Reading from %s *****", input_file)
    with open(input_file, "r") as f:
        while True:
            line = f.readline()
            if not line:
                break
            total_read += 1
            if total_read % 500 == 0:
                logging.info("%d ...", total_read)

            output = create_instance(tokenizer, line, args.max_seq_length)
            features = write_instance_to_file(writer, instance=output)
            total_written += 1

            if total_written <= 20:
                logging.info("***** Example *****")
                logging.info("input tokens: %s", tokenizer.decode(output["input_ids"][:-1]))
                logging.info("label tokens: %s", tokenizer.decode(output["input_ids"][1:]))

                for feature_name in features.keys():
                    feature = features[feature_name]
                    logging.info("%s: %s", feature_name, feature)

    writer.commit()
    logging.info("Wrote %d total instances", total_written)


if __name__ == "__main__":
    main()
@@ -0,0 +1,59 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""download the CNN & DailyMail dataset for the Summarization task"""

import argparse
from datasets import load_dataset


def generate_txt(url, split_, number=None, version="3.0.0"):
    """
    generate a txt file of the cnn_dailymail dataset

    Args:
        url (str): directory of the dataset txt file.
        split_ (str): test or train.
        number (int): top-n number of samples from the dataset
        version (str): "3.0.0" by default

    """
    cnn = load_dataset("cnn_dailymail", version, split=split_)
    if number == -1:
        number = len(cnn)
    f = open(url + split_ + '.txt', 'w')
    for idx in range(number):
        article = cnn[idx]['article']
        article = article.replace('\n', ' ')
        highlights = cnn[idx]['highlights']
        highlights = highlights.replace('\n', ' ')
        f.write(article + "\t" + highlights + '\n')
    f.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Download CNN_Dailymail 3.0.0 using datasets by Huggingface')
    parser.add_argument('--dir', type=str, default="", help="directory of dataset")
    parser.add_argument('--split', type=str, default='test', help="[test,train]")
    parser.add_argument('--num', type=int, default=-1,
                        help=" number of samples by default order. "
                             "If num is -1, it will download the whole dataset. Default: -1")
    args = parser.parse_args()

    data_directory = args.dir
    split = args.split
    num = args.num

    generate_txt(url=data_directory, split_=split, number=num)
@@ -0,0 +1,135 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Evaluate reading comprehension results against the additional answers."""

import json
import re
import string
import argparse
from collections import Counter


def get_normalize_answer_token(string_):
    """normalize the answer tokens: lower the text and remove punctuation, articles and extra whitespace"""
    def remove_articles(text):
        regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
        return re.sub(regex, ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(char for char in text if char not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(string_)))).split()


def calculate_f1(pred_answer, gold_answer):
    """
    calculate the final F1 score against the additional answers
    """
    f1_score = 0
    pred_answer = get_normalize_answer_token(pred_answer)
    gold_answer = get_normalize_answer_token(gold_answer)
    common = Counter(pred_answer) & Counter(gold_answer)
    num_same = sum(common.values())
    # num_same is the number of shared tokens between pred_answer and gold_answer
    precision = 1.0 * num_same / len(pred_answer) if pred_answer else 0
    recall = 1.0 * num_same / len(gold_answer) if gold_answer else 0
    if not pred_answer and not gold_answer:
        f1_score = 1
    else:
        f1_score = 2 * precision * recall / float(precision + recall) if (precision + recall) != 0 else 0.0
    return f1_score
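
# A worked example of the token-level F1 above, hand-checked with invented answers
# (illustrative only, not used by the evaluation flow):
#   pred "in the garden" -> tokens ["in", "garden"]   (articles/punctuation removed)
#   gold "The garden."   -> tokens ["garden"]
#   num_same = 1, precision = 1/2, recall = 1/1, F1 = 2 * 0.5 * 1.0 / 1.5 ≈ 0.667
# e.g. round(calculate_f1("in the garden", "The garden."), 3) == 0.667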


def main():
    parser = argparse.ArgumentParser(description="Evaluate Reading Comprehension results with additional answers")
    parser.add_argument("--input_file", type=str, default="",
                        help="The log file path of evaluation in Reading Comprehension. ")
    parser.add_argument("--addition_file", type=str, default="", help="Coqa-dev-v1.0.json path")
    args_opt = parser.parse_args()
    input_file = args_opt.input_file
    addition_file = args_opt.addition_file

    find_word = 'Pred_answer:'
    find_word_length = len(find_word)
    pred_answer_list = []

    with open(input_file, 'r', encoding='utf-8') as f:
        while True:
            line = f.readline()
            if not line:
                break
            index = line.find(find_word)
            if index != -1:
                pred_answer = line[index + find_word_length:].strip()
                pred_answer_list.append(pred_answer)

    dataset = json.load(open(addition_file))
    pred_answer_num = 0
    total_f1score = 0
    average_f1score = 0
    data_num = len(pred_answer_list)

    for story in dataset['data']:
        questions = story['questions']
        multiple_answers = [story['answers']]
        multiple_answers += story['additional_answers'].values()
        for question in questions:
            pred_a = pred_answer_list[pred_answer_num]
            turn_id = question['turn_id']
            max_score = 0
            max_group = 0
            flag = 0
            # calculate the max score over the multiple answers and record its group index
            for i, answer in enumerate(multiple_answers):
                gold_a = answer[turn_id - 1]['input_text']
                score = calculate_f1(pred_a, gold_a)
                if score > max_score:
                    max_score = score
                    max_group = i
            gold_a = multiple_answers[max_group][turn_id - 1]['input_text']
            pred_answer_num += 1
            total_f1score += max_score
            average_f1score = total_f1score / pred_answer_num

            print('==================== data {} ===================='.format(pred_answer_num))
            print('| Gold_answer:{}'.format(gold_a))
            print('| Pred_answer:{}'.format(pred_a))
            print('| F1_Score:{:.8f}'.format(average_f1score))
            print('=====================================================\n')

            if pred_answer_num >= data_num:
                flag = 1
                break
        # Stop flag
        if flag:
            print('Finished evaluation with additional answers! \n')
            print("********************** Testing Finished **********************")
            print('| Test file name: {}'.format(input_file))
            print('| Final F1 score: {:.8f}'.format(average_f1score))
            print('| Total data num: {}'.format(pred_answer_num))
            print("**************************************************************")
            break


if __name__ == "__main__":
    main()
File diff suppressed because it is too large
@@ -0,0 +1,60 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_cbt.sh"
echo "for example: bash scripts/run_cbt.sh"
echo "metric method: Accuracy"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_cbt.log"

# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_CBT_task.py \
    --device_target="Ascend" \
    --device_id=4 \
    --num_choice=10 \
    --metric_method="Accuracy" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 &
@@ -0,0 +1,68 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_lambada.sh"
echo "for example: bash scripts/run_lambada.sh"
echo "metric method include: [Accuracy, PPL]"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_lambada.log"

# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

# tokenizer path
tokenizer_file_path=""

# stopword path
stop_word_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_lambada.py \
    --device_target="Ascend" \
    --device_id=1 \
    --metric_method="PPL" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --generate_length_dynamically="true" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path \
    --tokenizer_file_path=$tokenizer_file_path \
    --stop_word_file_path=$stop_word_file_path >> $output_log 2>&1 &
@@ -0,0 +1,59 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_language_model.sh"
echo "for example: bash scripts/run_language_model.sh"
echo "metric method: PPL"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_language_model.log"

# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_language_model.py \
    --device_target="Ascend" \
    --device_id=4 \
    --metric_method="PPL" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path >> $output_log 2>&1 &
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_read_comprehension.sh"
echo "for example: bash scripts/run_read_comprehension.sh"
echo "metric method: F1"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_read_comprehension.log"

# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

# tokenizer path
tokenizer_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_ReadComprehension.py \
    --device_target="Ascend" \
    --device_id=7 \
    --metric_method="F1" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path \
    --tokenizer_file_path=$tokenizer_file_path \
    --generate_length=55 \
    --top_k=1 \
    --top_p="1.0" \
    --temperature="1.0" >> $output_log 2>&1 &
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_summarization.sh"
echo "for example: bash scripts/run_summarization.sh"
echo "eval_load_param_mode include: [zero-shot, finetuned]. Default: finetuned"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_summarization.log"

# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

# tokenizer path
tokenizer_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_summarization.py \
    --device_target="Ascend" \
    --device_id=0 \
    --do_train="false" \
    --do_eval="true" \
    --metric_method="Rouge" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --top_k=2 \
    --top_p="1.0" \
    --generate_length=100 \
    --temperature="1.0" \
    --eval_type="finetuned" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path \
    --tokenizer_file_path=$tokenizer_file_path >> $output_log 2>&1 &
@@ -0,0 +1,67 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_translation.sh"
echo "for example: bash scripts/run_translation.sh"
echo "metric method: BLEU"
echo "eval_type include: [zero-shot, finetuned]. Default: zero-shot"
echo "=============================================================================================================="

CUR_DIR=`pwd`
mkdir -p ms_log
output_log="${CUR_DIR}/ms_log/gpt2_translation.log"

# create file and head line
echo " | Eval log file: " > $output_log
echo $output_log >> $output_log

# checkpoint path
save_finetune_ckpt_path=""
load_pretrain_ckpt_path=""
load_eval_ckpt_path=""

# dataset path
train_data_file_path=""
eval_data_file_path=""

# tokenizer path
tokenizer_file_path=""

PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python ${PROJECT_DIR}/../run_translation.py \
    --device_target="Ascend" \
    --device_id=4 \
    --metric_method="BLEU" \
    --do_train="false" \
    --do_eval="true" \
    --eval_type="zero-shot" \
    --epoch_num=1 \
    --train_data_shuffle="true" \
    --eval_data_shuffle="false" \
    --save_finetune_ckpt_path=$save_finetune_ckpt_path \
    --load_pretrain_ckpt_path=$load_pretrain_ckpt_path \
    --load_finetune_ckpt_path=$load_eval_ckpt_path \
    --train_data_file_path=$train_data_file_path \
    --eval_data_file_path=$eval_data_file_path \
    --tokenizer_file_path=$tokenizer_file_path \
    --generate_length=100 \
    --top_k=1 \
    --top_p="1.0" \
    --temperature="1.0" >> $output_log 2>&1 &
Some files were not shown because too many files have changed in this diff