# -*- coding: utf-8 -*-
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT-2 finetuning and evaluation script for the summarization task.
"""

import time
import argparse

from mindspore import context
from mindspore.nn import AdamWeightDecay, Lamb, Momentum
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, TimeMonitor, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.GPT2ForSummarization import GPT2SummarizationModel
from src.gpt2_for_finetune import GPT2Summarization, GPT2FinetuneCell
from src.finetune_eval_config import cfg, gpt2_net_cfg
from src.utils.metric_method import Rouge
from src.dataset import create_language_model_dataset
from src.utils.lr_schedule import GPT2LearningRate
from src.utils.tokenization import Tokenizer
from src.utils.task_utils import clean_hypo, modify_paramdict
from src.GPT2_generation import GenerateForSummarization


def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoint_path="", epoch_num=1):
    """
    Do train.

    Args:
        dataset: the training dataset.
        network: the network with loss.
        load_checkpoint_path: file path of the pretrained model checkpoint to load.
        save_checkpoint_path: file path where the finetuned model checkpoint will be saved.
        epoch_num: the number of epochs.
    """
    if load_checkpoint_path == "":
        raise ValueError("Pretrained model is missing; the finetune task must load a pretrained model!")

    steps_per_epoch = dataset.get_dataset_size()

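    # optimizer: the learning-rate schedule warms up over the first 10% of the
    # total steps, then decays polynomially (per `power`) to end_learning_rate.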
    if cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()

        # apply weight decay only to the parameters selected by the decay filter
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0}]
        optimizer = AdamWeightDecay(group_params, lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == 'Lamb':
        lr_schedule = GPT2LearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=int(steps_per_epoch * epoch_num * 0.1),
                                       decay_steps=steps_per_epoch * epoch_num,
                                       power=cfg.Lamb.power)
        optimizer = Lamb(network.trainable_params(), lr_schedule)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), cfg.Momentum.learning_rate, cfg.Momentum.momentum)
    else:
        raise Exception("Optimizer not supported. Supported: [AdamWeightDecay, Lamb, Momentum]")

    # load checkpoint into network
    ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch, keep_checkpoint_max=1)
    prefix_name = "gpt2_summarization_" + str(cfg.gpt2_network) + "_" + str(cfg.optimizer) + "_" \
                  + str(epoch_num) + "_bs" + str(gpt2_net_cfg.batch_size)
    ckpoint_cb = ModelCheckpoint(prefix=prefix_name,
                                 directory=None if save_checkpoint_path == "" else save_checkpoint_path,
                                 config=ckpt_config)
    param_dict = load_checkpoint(load_checkpoint_path)

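    # The pretrained checkpoint stores backbone weights without the finetune
    # wrapper's prefix, so remap every parameter name under 'gpt2.gpt2.'. The
    # LM head weight is tied to the token embedding table (GPT-2 weight tying).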
    final_param_dict = {}
    for name, param in param_dict.items():
        final_param_dict['gpt2.gpt2.' + name] = param
    final_param_dict['gpt2.lm_head.weight'] = param_dict['gpt2_embedding_lookup.embedding_table']

    load_param_into_net(network, final_param_dict)
    print("Loaded pretrained parameters successfully!\n")

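    # Dynamic loss scaling: start from a large scale (2**32); on overflow the
    # scale is reduced by scale_factor=2, and it grows again after
    # scale_window=1000 consecutive overflow-free steps.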
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 32, scale_factor=2, scale_window=1000)
    netwithgrads = GPT2FinetuneCell(network, optimizer=optimizer, scale_update_cell=update_cell)
    netwithgrads.set_train(True)

    loss_cb = LossMonitor(per_print_times=1)
    model = Model(netwithgrads)
    callbacks = [TimeMonitor(dataset.get_dataset_size()), loss_cb, ckpoint_cb]

    print("============== Starting Finetuning ==============")
    model.train(epoch_num, dataset, callbacks=callbacks, dataset_sink_mode=False)
    print("============== Finetuning Success ==============")


def eval_result_print(metric="Rouge", callback=None):
    """
    Print the evaluation result.
    """
    if metric == "Rouge":
        # Rouge-AVG is the arithmetic mean of Rouge-1, Rouge-2 and Rouge-L.
        print("Rouge-1 {:.8f}, Rouge-2 {:.8f}, Rouge-L {:.8f}, Rouge-AVG {:.8f}".
              format(callback.Rouge1 / callback.total_num,
                     callback.Rouge2 / callback.total_num,
                     callback.RougeL / callback.total_num,
                     (callback.Rouge1 + callback.Rouge2 + callback.RougeL) / (3.0 * callback.total_num)))
    else:
        raise ValueError("metric method '{}' not supported, support: [Rouge].".format(str(metric)))


def do_eval(dataset=None, network=None, metric=None, load_checkpoint_path="", eval_type=None, tokenizer_file="",
            top_k=None, top_p=None, temperature=None, generate_length=None):
    """
    Do evaluation on the summarization task.
    """
    if load_checkpoint_path == "":
        raise ValueError("Finetuned model is missing; the evaluation task must load a finetuned model!")
    if metric.lower() == "rouge":
        print("Prepare to calculate the Rouge score ...")
        callback = Rouge()

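        # Build the network in inference mode and load the finetuned weights.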
        gpt2_loss = network(config=gpt2_net_cfg,
                            is_training=False,
                            use_one_hot_embeddings=False)
        gpt2_loss.set_train(False)
        param_dict = load_checkpoint(load_checkpoint_path)

        reorganized_param_dict = modify_paramdict(param_dict, mode=eval_type, model_prefix="gpt2.")
        load_param_into_net(gpt2_loss, reorganized_param_dict)

        # load nn.Cell into Model and initiate tokenizer and Sample
        model = Model(gpt2_loss)
        tokenizer = Tokenizer(vocab_file=tokenizer_file + 'gpt2-vocab.json',
                              merge_file=tokenizer_file + 'gpt2-merges.txt')
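        # NOTE: tokenizer_file is concatenated directly with the file names, so
        # it is expected to end with a path separator (e.g. '/path/to/vocab/').
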
        # load data and process text generation
        columns_list = ["input_ids", "input_mask", "label_ids"]

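        # Build the generator: topk/topp/temperature shape the sampling
        # distribution, generate_length caps the number of generated tokens,
        # and select_sentence controls how many generated sentences are kept
        # for the summary hypothesis (see GenerateForSummarization).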
        summarization_generator = GenerateForSummarization(model,
                                                           config=gpt2_net_cfg,
                                                           tokenizer=tokenizer,
                                                           select_sentence=3,
                                                           eval_type=eval_type,
                                                           topk=top_k,
                                                           topp=float(top_p),
                                                           temperature=float(temperature),
                                                           generate_length=generate_length)
        num_data = 1
        print("==================== [Summarization] Testing ====================")
        for data in dataset.create_dict_iterator():
            input_data = []
            for value in columns_list:
                input_data.append(data[value])
            input_ids, _, label_ids = input_data
            print(" | [ROUGE] number : {} / {} ".format(num_data, dataset.get_dataset_size()))
            print("input_ids shape: {}".format(input_ids.shape))
            print("label_ids shape: {}".format(label_ids.shape))

            hypothesis, ref = summarization_generator.generate_for_summarization(input_ids)
            if ref[0] is None or ref[0] == '':
                print("The reference is empty, skipping this sample!")
                continue

            print("REF str:\n ", ref, "\nHYPO str:\n", hypothesis, "\n")
            # normalize before scoring: strip generation artifacts and lowercase
            for batch_idx in range(gpt2_net_cfg.batch_size):
                hypothesis[batch_idx] = clean_hypo(hypothesis[batch_idx]).lower()
                ref[batch_idx] = ref[batch_idx].lower()

            callback.update(hypothesis, ref)
            num_data += 1

        print("\n\n")
        print("**********************************************************")
        eval_result_print(metric, callback)
        print("******************** Testing Finished ********************")

    else:
        raise ValueError("Metric method not supported in summarization; supported: [Rouge]")


def run_summarization():
    """
    Run the summarization task.
    """
    # set argument parser
    parser = argparse.ArgumentParser(description="Finetune and Evaluate Summarization")

    # context and task settings
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device type. Default: Ascend.")
    parser.add_argument("--device_id", type=int, default=4,
                        help="ID of the target device.")
    parser.add_argument("--do_train", type=str, default="false",
                        help="Enable train. Default: false.")
    parser.add_argument("--do_eval", type=str, default="true",
                        help="Enable evaluation. Default: true.")
    parser.add_argument("--eval_type", type=str, default="finetuned",
                        help="The type of evaluation including [zero-shot, finetuned]. Default: finetuned.")
    parser.add_argument("--metric_method", type=str, default="Rouge",
                        help="The eval method including [Rouge(Rouge1, Rouge2, RougeL, Rouge Avg)]. Default: Rouge.")
    parser.add_argument("--epoch_num", type=int, default=2,
                        help="Epoch number. Default: 2.")

    # dataset and params_dict file settings
    parser.add_argument("--train_data_shuffle", type=str, default="true",
                        help="Enable train data shuffle. Default: true.")
    parser.add_argument("--eval_data_shuffle", type=str, default="false",
                        help="Enable eval data shuffle. Default: false.")
    parser.add_argument("--save_finetune_ckpt_path", type=str, default="",
                        help="Path for saving the finetuned checkpoint.")
    parser.add_argument("--load_pretrain_ckpt_path", type=str, default="",
                        help="File path of the pretrained checkpoint to load.")
    parser.add_argument("--load_finetune_ckpt_path", type=str, default="",
                        help="File path of the finetuned checkpoint to load.")
    parser.add_argument("--train_data_file_path", type=str, default="",
                        help="Train data file path; an absolute path is recommended.")
    parser.add_argument("--eval_data_file_path", type=str, default="",
                        help="Eval data file path; an absolute path is recommended.")

    # sampling settings
    parser.add_argument("--top_k", type=int, default=2,
                        help="Number of highest-probability tokens kept for top-k sampling.")
    parser.add_argument("--top_p", type=str, default="1.0",
                        help="Cumulative probability threshold for top-p (nucleus) sampling.")
    parser.add_argument("--generate_length", type=int, default=100,
                        help="The number of tokens to generate.")
    parser.add_argument("--temperature", type=str, default="1.0",
                        help="Temperature applied to the logits for sampling.")
    parser.add_argument("--tokenizer_file_path", type=str, default="",
                        help="Path of the directory containing the vocab and merges files.")
    args_opt = parser.parse_args()

    epoch_num = args_opt.epoch_num
    metric = args_opt.metric_method
    save_finetune_ckpt_path = args_opt.save_finetune_ckpt_path
    load_finetune_ckpt_path = args_opt.load_finetune_ckpt_path
    load_pretrain_ckpt_path = args_opt.load_pretrain_ckpt_path
    eval_type = args_opt.eval_type
    tokenizer_file = args_opt.tokenizer_file_path

    if args_opt.do_train.lower() == "false" and args_opt.do_eval.lower() == "false":
        raise ValueError("At least one of 'do_train' or 'do_eval' must be true")
    if args_opt.do_train.lower() == "true" and args_opt.train_data_file_path == "":
        raise ValueError("'train_data_file_path' must be set for the finetune task")
    if args_opt.do_eval.lower() == "true" and args_opt.eval_data_file_path == "":
        raise ValueError("'eval_data_file_path' must be set for the evaluation task")

    device = args_opt.device_target
    if device == "Ascend":
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
        context.set_auto_parallel_context(parallel_mode="stand_alone")
        print(" | Device: {} | Device id: {}".format(device, args_opt.device_id))
    else:
        raise Exception("Device target error; only Ascend is supported.")

    if args_opt.do_train.lower() == "true":
        train_data_file_path = args_opt.train_data_file_path
        gpt2_loss = GPT2Summarization(config=gpt2_net_cfg,
                                      is_training=True,
                                      use_one_hot_embeddings=False)
        print("============== Start Loading Train Dataset ============")
        train_dataset = create_language_model_dataset(do_shuffle=(args_opt.train_data_shuffle.lower() == "true"),
                                                      dataset_path=train_data_file_path)
        do_train(train_dataset, gpt2_loss, load_pretrain_ckpt_path, save_finetune_ckpt_path, epoch_num)

    if args_opt.do_eval.lower() == "true":
        eval_dataset_file_path = args_opt.eval_data_file_path
        print("============== Start Loading Evaluation Dataset ============")
        eval_dataset = create_language_model_dataset(do_shuffle=(args_opt.eval_data_shuffle.lower() == "true"),
                                                     dataset_path=eval_dataset_file_path)
        do_eval(eval_dataset, GPT2SummarizationModel, metric, load_finetune_ckpt_path, eval_type, tokenizer_file,
                args_opt.top_k, args_opt.top_p, args_opt.temperature, args_opt.generate_length)


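# Example invocation (script name and paths are illustrative only):
#   python run_summarization.py --device_target=Ascend --device_id=0 \
#       --do_train=true --do_eval=true \
#       --train_data_file_path=/path/to/train.mindrecord \
#       --eval_data_file_path=/path/to/test.mindrecord \
#       --load_pretrain_ckpt_path=/path/to/gpt2_pretrain.ckpt \
#       --load_finetune_ckpt_path=/path/to/gpt2_finetune.ckpt \
#       --tokenizer_file_path=/path/to/tokenizer/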
if __name__ == "__main__":
    print("Start Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
    run_summarization()
    print("End Time: ", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))