# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT evaluation script.
"""
import math
import argparse
import numpy as np
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inference import generate
from src.dataset import create_dataset
from src.gpt import GPT, EvalNet, GPTWithLoss, CrossEntropyLoss
from src.utils import GPTConfig

context.set_context(mode=context.GRAPH_MODE)


def ppl_score(probs, length, is_logsoftmax=True):
    """ calculate perplexity with prob or log_prob inputs """
    probs = probs[:length]
    if is_logsoftmax:
        # log-prob inputs: ppl = exp(-average log-probability)
        prob = np.sum(probs) / length
        ppl = 1.0 / np.power(np.e, prob)
    else:
        # raw-prob inputs: ppl = geometric mean of the inverse probabilities
        prob = 1.0
        for p in probs:
            prob *= (1.0 / p)
        ppl = np.power(prob, 1.0 / length)
    return ppl
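
# Sanity check for ppl_score (hypothetical values): with log-softmax inputs the
# score reduces to exp(-mean(log_probs)), e.g.
#   ppl_score(np.log([0.5, 0.25]), length=2)   # exp(1.0397...) ≈ 2.83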


def get_ppl(model, dataset):
    """ calculate perplexity for input dataset """
    losses = []
    num_samples = 0
    for data in dataset:
        data = data[0].asnumpy()
        input_ids = data
        # `model` is GPTWithLoss (as constructed in run_eval), so the forward
        # pass returns the average per-token loss for the batch
        loss = model(Tensor(input_ids, mstype.int32)).asnumpy()
        losses.append(loss * len(data))
        num_samples += len(data)
    # sample-weighted average loss (the last batch may be smaller, since the
    # dataset is created with drop=False)
    val_loss = sum(losses) / num_samples
    ppl = math.exp(min(20, val_loss))
    return ppl
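
# Note: math.exp(min(20, val_loss)) caps the reported perplexity at e**20
# (about 4.85e8) to avoid overflow on a diverged model; e.g. an average token
# loss of 3.0 maps to a perplexity of exp(3.0) ≈ 20.1.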


def get_acc(model, dataset):
    """ calculate accuracy for input dataset """
    total_num = 0
    acc_num = 0
    for data in dataset:
        data = data[0].asnumpy()
        input_mask = (data != 0).astype(np.int32)
        length = np.sum(input_mask, 1)
        label = np.zeros(length.shape)
        for i, idx in enumerate(length):
            # the last non-pad token is the gold label; blank it out so the
            # model must predict it from the preceding context
            label[i] = data[i][idx - 1]
            input_mask[i][idx - 1] = 0
            data[i][idx - 1] = 0
        # recompute lengths, excluding <|endoftext|> tokens (id 50256)
        length = np.sum(data != 50256, 1)
        input_ids = data
        # `model` is EvalNet(generate=False) here, which returns the argmax
        # token id for every position
        logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
        logits = logits.reshape(len(length), -1)
        predicted_label = np.zeros(length.shape)
        for i, idx in enumerate(length):
            # read the model's prediction at the position just before the
            # blanked slot
            predicted_label[i] = logits[i][idx - 2]
        total_num += len(label)
        acc_num += sum(label == predicted_label)
    acc = acc_num / total_num
    return acc
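
# get_acc implements last-word prediction accuracy: the final token of each
# sample is blanked out and must be recovered from the preceding context,
# matching the cloze-style LAMBADA task accepted by run_eval below.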


def run_eval():
    """ evaluate scripts """
    parser = argparse.ArgumentParser(description="GPT inferencing")
    parser.add_argument('--task_type', type=str, default="", help="Evaluation task.")
    parser.add_argument('--metrics', type=str, default="acc", choices=["ppl", "acc"], help="Evaluation metrics.")
    parser.add_argument('--ckpt_path', type=str, default="", help="path of checkpoint file.")
    parser.add_argument('--data_path', type=str, default="", help="path of MindRecord file.")
    args = parser.parse_args()

    task = args.task_type
    metrics = args.metrics
    ckpt_path = args.ckpt_path
    if task not in ["generate", "lambada", "wikitext"]:
        raise ValueError("{} is not supported now".format(task))
    if metrics not in ["acc", "ppl"]:
        raise ValueError("{} is not supported now".format(metrics))

    config = GPTConfig(batch_size=16,
                       seq_length=1024,
                       vocab_size=50257,
                       embedding_size=1024,
                       num_layers=24,
                       num_heads=16,
                       expand_ratio=4,
                       post_layernorm_residual=False,
                       dropout_rate=0.0,
                       compute_dtype=mstype.float16,
                       use_past=False)
    ckpt_dict = load_checkpoint(ckpt_path)

    gpt = GPT(config)
    if task == "generate":
        gpt_eval = EvalNet(gpt, generate=True)
    elif metrics == "acc":
        gpt_eval = EvalNet(gpt, generate=False)
    else:
        # perplexity needs the loss, so wrap the backbone with the loss cell
        loss = CrossEntropyLoss(config)
        gpt_eval = GPTWithLoss(gpt, loss)
    gpt_eval.set_train(False)
    load_param_into_net(gpt_eval, ckpt_dict)

    if task == "generate":
        # hard-coded prompt token ids used to seed generation
        start_sentence = [6170, 318, 257]
        input_ids = np.array(start_sentence).reshape(1, -1)
        outputs = generate(gpt_eval, input_ids, config.seq_length)
        output_list = outputs.tolist()
        print("output id is ", output_list)
    else:
        data_path = args.data_path
        eval_dataset = create_dataset(config.batch_size, data_path=data_path, drop=False)
        if metrics == "acc":
            acc = get_acc(gpt_eval, eval_dataset)
            print("Accuracy is ", acc)
        elif metrics == "ppl":
            ppl = get_ppl(gpt_eval, eval_dataset)
            print("Perplexity is ", ppl)


if __name__ == "__main__":
    run_eval()