You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/research/nlp/ternarybert/eval.py

108 lines
4.5 KiB

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval standalone script"""
import os
import re
import argparse
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.config import eval_cfg, student_net_cfg, task_cfg
from src.tinybert_model import BertModelCLS
# File name of the evaluation split, looked up inside <data_dir>/<task_name>/.
DATA_NAME = "eval.tf_record"
def parse_args(argv=None):
    """
    Parse the command-line options of the evaluation script.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``device_target``, ``device_id``, ``model_dir``,
        ``data_dir``, ``task_name``, ``dataset_type`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description='ternarybert evaluation')
    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
                        help='Device where the code will be implemented. (Default: GPU)')
    parser.add_argument('--device_id', type=int, default=0, help='Device id. (Default: 0)')
    parser.add_argument('--model_dir', type=str, default='',
                        help='The checkpoint directory of model; checkpoints are read from '
                             '<model_dir>/<task_name>.')
    parser.add_argument('--data_dir', type=str, default='', help='Data directory.')
    # Help texts below previously said "task to train" (copy-paste from the
    # training script); this script only evaluates.
    parser.add_argument('--task_name', type=str, default='sts-b', choices=['sts-b', 'qnli', 'mnli'],
                        help='The name of the task to evaluate. (Default: sts-b)')
    parser.add_argument('--dataset_type', type=str, default='tfrecord', choices=['tfrecord', 'mindrecord'],
                        help='The format of the evaluation dataset. (Default: tfrecord)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size for evaluating. (Default: 32)')
    return parser.parse_args(argv)
def get_ckpt(ckpt_file):
    """
    Return the path of the most recently modified checkpoint in a directory.

    Only regular files ending in ``.ckpt`` are considered, so sub-directories
    and side files (e.g. graph ``.meta`` dumps) are never picked.

    Args:
        ckpt_file: directory to search. The name is kept for backward
            compatibility; it is a directory, not a file.

    Returns:
        ``os.path.join(ckpt_file, name)`` for the ``.ckpt`` file with the latest
        modification time. Ties are broken by file name, so the result is
        deterministic.

    Raises:
        FileNotFoundError: if the directory does not exist or contains no
            ``.ckpt`` file.
    """
    candidates = [
        name for name in os.listdir(ckpt_file)
        if name.endswith('.ckpt') and os.path.isfile(os.path.join(ckpt_file, name))
    ]
    if not candidates:
        raise FileNotFoundError('no .ckpt file found in {}'.format(ckpt_file))
    latest = max(candidates,
                 key=lambda name: (os.path.getmtime(os.path.join(ckpt_file, name)), name))
    return os.path.join(ckpt_file, latest)
def do_eval_standalone(args_opt):
    """
    Evaluate the newest student checkpoint of one task on its eval split.

    Picks the checkpoint returned by ``get_ckpt`` for
    ``<model_dir>/<task_name>``. It then renames the checkpoint's parameters to
    the ``BertModelCLS`` names and runs inference over the file ``DATA_NAME``
    in ``<data_dir>/<task_name>``. The task metric is printed after every step
    and once more at the end.

    Args:
        args_opt: namespace produced by ``parse_args``. Reads ``model_dir``,
            ``task_name``, ``data_dir``, ``dataset_type``, ``batch_size``,
            ``device_target`` and ``device_id``.
    """
    ckpt_dir = os.path.join(args_opt.model_dir, args_opt.task_name)
    ckpt_file = get_ckpt(ckpt_dir)
    print('ckpt file:', ckpt_file)

    task = task_cfg[args_opt.task_name]
    # Overwrite the shared config objects with task- and CLI-specific values.
    student_net_cfg.seq_length = task.seq_length
    eval_cfg.batch_size = args_opt.batch_size
    eval_data_dir = os.path.join(args_opt.data_dir, args_opt.task_name, DATA_NAME)

    # Read device_id from args_opt, not from the module-level ``args`` global,
    # so this function also works when imported and called programmatically.
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target,
                        device_id=args_opt.device_id)

    eval_dataset = create_dataset(batch_size=eval_cfg.batch_size,
                                  device_num=1,
                                  rank=0,
                                  # NOTE(review): passed as a string, presumably
                                  # compared against 'true' inside create_dataset.
                                  do_shuffle='false',
                                  data_dir=eval_data_dir,
                                  data_type=args_opt.dataset_type,
                                  seq_length=task.seq_length,
                                  task_type=task.task_type,
                                  drop_remainder=False)
    print('eval dataset size:', eval_dataset.get_dataset_size())
    print('eval dataset batch size:', eval_dataset.get_batch_size())

    eval_model = BertModelCLS(student_net_cfg, False, task.num_labels, 0.0, phase_type='student')

    # Rename checkpoint keys to the eval model's parameter names.
    # First, every 'tinybert_' becomes 'bert_'. Then a leading 'bert' plus any
    # single character is dropped.
    # NOTE(review): the '.' in '^bert.' is unescaped, so a leading 'bert_' is
    # stripped too. It is kept as-is because the checkpoint key layout is not
    # visible here; confirm before tightening to r'^bert\.'.
    param_dict = load_checkpoint(ckpt_file)
    new_param_dict = {
        re.sub('^bert.', '', re.sub('tinybert_', 'bert_', key)): value
        for key, value in param_dict.items()
    }
    load_param_into_net(eval_model, new_param_dict)
    eval_model.set_train(False)

    columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
    callback = task.metrics()
    for step, data in enumerate(eval_dataset.create_dict_iterator()):
        input_ids, input_mask, token_type_id, label_ids = (data[name] for name in columns_list)
        # The model returns a 4-tuple; only the third element (logits) is scored.
        # Argument order is (ids, token types, mask).
        _, _, logits, _ = eval_model(input_ids, token_type_id, input_mask)
        callback.update(logits, label_ids)
        print('eval step: {}, {}: {}'.format(step, callback.name, callback.get_metrics()))
    metrics = callback.get_metrics()
    print('The best {}: {}'.format(callback.name, metrics))
if __name__ == '__main__':
    # Bind the parsed options to the module-level name ``args``; other code in
    # this file may read that global, so the binding is kept.
    args = parse_args()
    do_eval_standalone(args_opt=args)