# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CRNN-Seq2Seq-OCR evaluation."""

import os
import codecs
import argparse

import numpy as np

import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.config import config
from src.utils import initialize_vocabulary
from src.dataset import create_ocr_val_dataset
from src.attention_ocr import AttentionOCRInfer

set_seed(1)


def text_standardization(text_in):
    """Normalize whitespace and replace full-width punctuation with ASCII equivalents."""
    stand_text = text_in.strip()
    stand_text = ' '.join(stand_text.split())
    # Map full-width CJK punctuation to its ASCII counterpart.
    stand_text = stand_text.replace(u'（', u'(')
    stand_text = stand_text.replace(u'）', u')')
    stand_text = stand_text.replace(u'：', u':')
    return stand_text
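
# Minimal usage sketch (illustrative string, not taken from the dataset):
#   text_standardization(u'  OCR  （demo）： text ')  ->  u'OCR (demo): text'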


def LCS_length(str1, str2):
    """Calculate the length of the longest common subsequence of str1 and str2."""
    if str1 is None or str2 is None:
        return 0

    len1 = len(str1)
    len2 = len(str2)
    if len1 == 0 or len2 == 0:
        return 0

    # Two-row rolling DP table: row i % 2 is the current row, row (i - 1) % 2 the previous one.
    lcs = [[0 for _ in range(len2 + 1)] for _ in range(2)]
    for i in range(1, len1 + 1):
        for j in range(1, len2 + 1):
            if str1[i - 1] == str2[j - 1]:
                lcs[i % 2][j] = lcs[(i - 1) % 2][j - 1] + 1
            else:
                if lcs[i % 2][j - 1] >= lcs[(i - 1) % 2][j]:
                    lcs[i % 2][j] = lcs[i % 2][j - 1]
                else:
                    lcs[i % 2][j] = lcs[(i - 1) % 2][j]

    return lcs[len1 % 2][-1]
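
# Quick sanity check (illustrative values):
#   LCS_length("kitten", "sitting") == 4   # a longest common subsequence is "ittn"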


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="CRNN-Seq2Seq-OCR Evaluation")
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Test dataset path.")
    parser.add_argument("--checkpoint_path", type=str, default=None,
                        help="Checkpoint of AttentionOCR (default: None).")
    parser.add_argument("--device_target", type=str, default="Ascend",
                        help="Device on which the code will run (default: Ascend).")
    parser.add_argument("--device_id", type=int, default=0, help="Device id (default: 0).")

    args = parser.parse_args()

    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)

    prefix = "fsns.mindrecord"
    mindrecord_dir = args.dataset_path
    mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
    print("mindrecord_file", mindrecord_file)
    dataset = create_ocr_val_dataset(mindrecord_file, config.eval_batch_size)
    data_loader = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
    print("Dataset creation done!")

    # Network
    network = AttentionOCRInfer(config.eval_batch_size,
                                int(config.img_width / 4),
                                config.encoder_hidden_size,
                                config.decoder_hidden_size,
                                config.decoder_output_size,
                                config.max_length,
                                config.dropout_p)

    ckpt = load_checkpoint(args.checkpoint_path)
    load_param_into_net(network, ckpt)
    network.set_train(False)
    print("Checkpoint loading done!")

    vocab, rev_vocab = initialize_vocabulary(config.vocab_path)
    eos_id = config.characters_dictionary.get("eos_id")
    sos_id = config.characters_dictionary.get("go_id")

    num_correct_char = 0
    num_total_char = 0
    num_correct_word = 0
    num_total_word = 0

    correct_file = 'result_correct.txt'
    incorrect_file = 'result_incorrect.txt'

    with codecs.open(correct_file, 'w', encoding='utf-8') as fp_output_correct, \
            codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:

        for data in data_loader:
            images = Tensor(data["image"])
            decoder_inputs = Tensor(data["decoder_input"])
            decoder_targets = Tensor(data["decoder_target"])

            decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
                                             dtype=np.float16), mstype.float16)
            decoder_input = Tensor((np.ones((config.eval_batch_size, 1)) * sos_id).astype(np.int32))
            encoder_outputs = network.encoder(images)
            batch_decoded_label = []

            # Greedy decoding: feed each step's argmax token back in as the next decoder input.
            for di in range(decoder_inputs.shape[1]):
                decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
                topi = P.Argmax()(decoder_output)
                ni = P.ExpandDims()(topi, 1)
                decoder_input = ni
                topi_id = topi.asnumpy()
                batch_decoded_label.append(topi_id)

            for b in range(config.eval_batch_size):
                text = data["annotation"][b].decode("utf8")
                text = text_standardization(text)
                decoded_label = list(np.array(batch_decoded_label)[:, b])
                decoded_words = []
                for idx in decoded_label:
                    if idx == eos_id:
                        break
                    decoded_words.append(rev_vocab[idx])
                predict = text_standardization("".join(decoded_words))

                if predict == text:
                    num_correct_word += 1
                    fp_output_correct.write('\t\t' + text + '\n')
                    fp_output_correct.write('\t\t' + predict + '\n\n')
                    print('correctly predicted: pred: {}, gt: {}'.format(predict, text))
                else:
                    fp_output_incorrect.write('\t\t' + text + '\n')
                    fp_output_incorrect.write('\t\t' + predict + '\n\n')
                    print('incorrectly predicted: pred: {}, gt: {}'.format(predict, text))

                num_total_word += 1
                # Character counts form an F1-style score: 2 * LCS(gt, pred) / (len(gt) + len(pred)).
                num_correct_char += 2 * LCS_length(text, predict)
                num_total_char += len(text) + len(predict)

    print('\nnum of correct characters = %d' % (num_correct_char))
    print('\nnum of total characters = %d' % (num_total_char))
    print('\nnum of correct words = %d' % (num_correct_word))
    print('\nnum of total words = %d' % (num_total_word))
    print('\ncharacter precision = %f' % (float(num_correct_char) / num_total_char))
    print('\nannotation precision = %f' % (float(num_correct_word) / num_total_word))
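
# Example invocation (hedged; the script name and paths below are placeholders):
#   python eval.py --dataset_path /path/to/mindrecord_dir \
#       --checkpoint_path /path/to/attention_ocr.ckpt \
#       --device_target Ascend --device_id 0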