# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
GPT evaluation script.
"""

import math
import argparse
import numpy as np
from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.inference import generate
from src.dataset import create_dataset
from src.gpt import GPT, EvalNet, GPTWithLoss, CrossEntropyLoss
from src.utils import GPTConfig

context.set_context(mode=context.GRAPH_MODE)


def ppl_score(probs, length, is_logsoftmax=True):
    """Calculate perplexity from prob or log-prob inputs (unused by run_eval)."""
    probs = probs[:length]
    if is_logsoftmax:
        # probs are log-probabilities: ppl = exp(-mean log-prob).
        prob = np.sum(probs) / length
        ppl = 1.0 / np.power(np.e, prob)
    else:
        # probs are raw probabilities: geometric mean of their inverses.
        prob = 1.0
        for p in probs:
            prob *= (1.0 / p)
        ppl = np.power(prob, 1.0 / length)
    return ppl


def get_ppl(model, dataset):
    """Calculate perplexity over the input dataset."""
    losses = []
    tokens = 0
    for data in dataset:
        input_ids = data[0].asnumpy()
        # GPTWithLoss returns the mean loss of the batch; weight it by batch size.
        loss = model(Tensor(input_ids, mstype.int32)).asnumpy()
        losses.append(loss * len(input_ids))
        tokens += len(input_ids)
    val_loss = sum(losses) / tokens
    # Clamp the exponent so a diverged loss cannot overflow exp().
    ppl = math.exp(min(20, val_loss))
    return ppl


def get_acc(model, dataset):
    """Calculate accuracy over the input dataset."""
    total_num = 0
    acc_num = 0
    for data in dataset:
        data = data[0].asnumpy()
        # Non-zero positions are valid tokens; the last one is the target label.
        input_mask = (data != 0).astype(np.int32)
        length = np.sum(input_mask, 1)
        label = np.zeros(length.shape)
        for i, idx in enumerate(length):
            label[i] = data[i][idx - 1]
            input_mask[i][idx - 1] = 0
            data[i][idx - 1] = 0
        # 50256 is the GPT-2 <|endoftext|> id; recount lengths after masking.
        length = np.sum(data != 50256, 1)
        input_ids = data
        logits = model(Tensor(input_ids, mstype.int32)).asnumpy()
        logits = logits.reshape(len(length), -1)
        predicted_label = np.zeros(length.shape)
        for i, idx in enumerate(length):
            # The model predicts the next token, so the prediction for the
            # masked-out label sits at the preceding position.
            predicted_label[i] = logits[i][idx - 2]
        total_num += len(label)
        acc_num += sum(label == predicted_label)
    acc = acc_num / total_num
    return acc


def run_eval():
    """Evaluation script entry point."""
    parser = argparse.ArgumentParser(description="GPT inference")
    parser.add_argument('--task_type', type=str, default="",
                        help="Evaluation task: generate, lambada or wikitext.")
    parser.add_argument('--metrics', type=str, default="acc", choices=["ppl", "acc"],
                        help="Evaluation metrics.")
    parser.add_argument('--ckpt_path', type=str, default="",
                        help="Path of the checkpoint file.")
    parser.add_argument('--data_path', type=str, default="",
                        help="Path of the MindRecord file.")
    args = parser.parse_args()

    task = args.task_type
    metrics = args.metrics
    ckpt_path = args.ckpt_path
    if task not in ["generate", "lambada", "wikitext"]:
        raise ValueError("{} is not a supported task".format(task))
    if metrics not in ["acc", "ppl"]:
        raise ValueError("{} is not a supported metric".format(metrics))

    config = GPTConfig(batch_size=16,
                       seq_length=1024,
                       vocab_size=50257,
                       embedding_size=1024,
                       num_layers=24,
                       num_heads=16,
                       expand_ratio=4,
                       post_layernorm_residual=False,
                       dropout_rate=0.0,
                       compute_dtype=mstype.float16,
                       use_past=False)
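    # NOTE: these hyperparameters roughly match the ~345M-parameter GPT-2
    # "medium" layout (24 layers, hidden size 1024, 16 heads) and must agree
    # with the architecture the checkpoint was trained with.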
    ckpt_dict = load_checkpoint(ckpt_path)

    gpt = GPT(config)
    if task == "generate":
        gpt_eval = EvalNet(gpt, generate=True)
    elif metrics == "acc":
        gpt_eval = EvalNet(gpt, generate=False)
    else:
        # Perplexity needs the loss value, so wrap the backbone with the loss cell.
        loss = CrossEntropyLoss(config)
        gpt_eval = GPTWithLoss(gpt, loss)
    gpt_eval.set_train(False)
    load_param_into_net(gpt_eval, ckpt_dict)

    if task == "generate":
        # Token ids of the prompt fed to autoregressive generation.
        start_sentence = [6170, 318, 257]
        input_ids = np.array(start_sentence).reshape(1, -1)
        outputs = generate(gpt_eval, input_ids, config.seq_length)
        output_list = outputs.tolist()
        print("output id is ", output_list)
    else:
        data_path = args.data_path
        eval_dataset = create_dataset(config.batch_size, data_path=data_path, drop=False)
        if metrics == "acc":
            acc = get_acc(gpt_eval, eval_dataset)
            print("Accuracy is ", acc)
        elif metrics == "ppl":
            ppl = get_ppl(gpt_eval, eval_dataset)
            print("Perplexity is ", ppl)


if __name__ == "__main__":
    run_eval()
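
# A minimal usage sketch (the paths below are placeholders, not shipped defaults):
#   python eval.py --task_type lambada --metrics acc \
#       --ckpt_path /path/to/gpt.ckpt --data_path /path/to/lambada.mindrecord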