|
|
|
@ -18,9 +18,11 @@ Bert evaluation script.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import argparse
|
|
|
|
|
import numpy as np
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
import mindspore.dataset as de
|
|
|
|
|
import mindspore.dataset.transforms.c_transforms as C
|
|
|
|
@ -105,8 +107,17 @@ def bert_predict(Evaluation):
|
|
|
|
|
'''
|
|
|
|
|
prediction function
|
|
|
|
|
'''
|
|
|
|
|
devid = int(os.getenv('DEVICE_ID'))
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
|
|
|
|
target = args_opt.device_target
|
|
|
|
|
if target == "Ascend":
|
|
|
|
|
devid = int(os.getenv('DEVICE_ID'))
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
|
|
|
|
|
elif target == "GPU":
|
|
|
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
|
|
|
|
if bert_net_cfg.compute_type != mstype.float32:
|
|
|
|
|
logger.warning('GPU only support fp32 temporarily, run with fp32.')
|
|
|
|
|
bert_net_cfg.compute_type = mstype.float32
|
|
|
|
|
else:
|
|
|
|
|
raise Exception("Target error, GPU or Ascend is supported.")
|
|
|
|
|
dataset = get_dataset(bert_net_cfg.batch_size, 1)
|
|
|
|
|
if cfg.use_crf:
|
|
|
|
|
net_for_pretraining = Evaluation(bert_net_cfg, False, num_labels=len(tag_to_index), use_crf=True,
|
|
|
|
@ -147,6 +158,9 @@ def test_eval():
|
|
|
|
|
callback.acc_num / callback.total_num))
|
|
|
|
|
print("==============================================================")
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Bert eval')
|
|
|
|
|
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
|
|
|
|
|
args_opt = parser.parse_args()
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
num_labels = cfg.num_labels
|
|
|
|
|
test_eval()
|
|
|
|
|