Merge pull request #1597 from tink2123/dygraph_for_srn

【Do not merge】Add srn model
revert-1929-fix_typo
xiaoting 4 years ago committed by GitHub
commit acd479ea46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,5 @@
Global:
use_gpu: true
use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
@ -59,7 +59,7 @@ Metric:
Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@ -78,7 +78,7 @@ Train:
Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image

@ -58,7 +58,7 @@ Metric:
Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@ -77,7 +77,7 @@ Train:
Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image

@ -63,7 +63,7 @@ Metric:
Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@ -82,7 +82,7 @@ Train:
Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image

@ -58,7 +58,7 @@ Metric:
Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@ -77,7 +77,7 @@ Train:
Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image

@ -56,7 +56,7 @@ Metric:
Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@ -75,7 +75,7 @@ Train:
Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image

@ -62,7 +62,7 @@ Metric:
Train:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@ -81,7 +81,7 @@ Train:
Eval:
dataset:
name: LMDBDateSet
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image

@ -0,0 +1,107 @@
Global:
use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/rec/srn_new
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 5000]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
num_heads: 8
infer_mode: False
use_space_char: False
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 10.0
lr:
learning_rate: 0.0001
Architecture:
model_type: rec
algorithm: SRN
in_channels: 1
Transform:
Backbone:
name: ResNetFPN
Head:
name: SRNHead
max_text_length: 25
num_heads: 8
num_encoder_TUs: 2
num_decoder_TUs: 4
hidden_dims: 512
Loss:
name: SRNLoss
PostProcess:
name: SRNLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/srn_train_data_duiqi
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SRNLabelEncode: # Class handling label
- SRNRecResizeImg:
image_shape: [1, 64, 256]
- KeepKeys:
keep_keys: ['image',
'label',
'length',
'encoder_word_pos',
'gsrm_word_pos',
'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2'] # dataloader will return list in this order
loader:
shuffle: False
batch_size_per_card: 64
drop_last: False
num_workers: 4
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SRNLabelEncode: # Class handling label
- SRNRecResizeImg:
image_shape: [1, 64, 256]
- KeepKeys:
keep_keys: ['image',
'label',
'length',
'encoder_word_pos',
'gsrm_word_pos',
'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 32
num_workers: 4

@ -33,7 +33,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDateSet
from ppocr.data.lmdb_dataset import LMDBDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet']
support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))

@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment
from .operators import *
from .label_ops import *

@ -102,6 +102,8 @@ class BaseRecLabelEncode(object):
support_character_type, character_type)
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
@ -217,3 +219,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length=25,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
char_num = len(self.character_str)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
text = text + [char_num] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx

@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
@ -77,6 +63,26 @@ class RecResizeImg(object):
return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
self.num_heads = num_heads
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_srn(img, self.image_shape)
data['image'] = norm_img
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
data['encoder_word_pos'] = encoder_word_pos
data['gsrm_word_pos'] = gsrm_word_pos
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
return data
def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape
h = img.shape[0]
@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = 0
max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
return padding_im
def resize_norm_img_srn(img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
[num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
[num_heads, 1, 1]) * [-1e9]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def flag():
"""
flag

@ -23,11 +23,14 @@ def build_loss(config):
# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_srn_loss import SRNLoss
# cls loss
from .cls_loss import ClsLoss
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')

@ -0,0 +1,47 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class SRNLoss(nn.Layer):
def __init__(self, **kwargs):
super(SRNLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
def forward(self, predicts, batch):
predict = predicts['predict']
word_predict = predicts['word_out']
gsrm_predict = predicts['gsrm_out']
label = batch[1]
casted_label = paddle.cast(x=label, dtype='int64')
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
cost_word = self.loss_func(word_predict, label=casted_label)
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
cost_vsfd = self.loss_func(predict, label=casted_label)
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}

@ -33,8 +33,6 @@ class RecMetric(object):
if pred == target:
correct_num += 1
all_num += 1
# if all_num < 10 and kwargs.get('show_str', False):
# print('{} -> {}'.format(pred, target))
self.correct_num += correct_num
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
@ -50,7 +48,7 @@ class RecMetric(object):
'norm_edit_dis': 0,
}
"""
acc = self.correct_num / self.all_num
acc = 1.0 * self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}

@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
def forward(self, x):
def forward(self, x, data=None):
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
if self.use_neck:
x = self.neck(x)
x = self.head(x)
if data is None:
x = self.head(x)
else:
x = self.head(x, data)
return x

@ -24,7 +24,8 @@ def build_backbone(config, model_type):
elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN']
from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
else:
raise NotImplementedError

File diff suppressed because it is too large Load Diff

@ -23,10 +23,13 @@ def build_head(config):
# rec head
from .rec_ctc_head import CTCHead
from .rec_srn_head import SRNHead
# cls head
from .cls_head import ClsHead
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead'
]
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -26,11 +26,12 @@ def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
]
config = copy.deepcopy(config)

@ -33,6 +33,9 @@ class BaseRecLabelDecode(object):
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
@ -109,7 +112,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
@ -158,3 +160,84 @@ class AttnLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
pred = preds['predict']
char_num = len(self.character_str) + 2
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = np.reshape(pred, [-1, char_num])
preds_idx = np.argmax(pred, axis=1)
preds_prob = np.max(pred, axis=1)
preds_idx = np.reshape(preds_idx, [-1, 25])
preds_prob = np.reshape(preds_prob, [-1, 25])
text = self.decode(preds_idx, preds_prob)
if label is None:
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
return text
label = self.decode(label)
return text, label
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx

@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", help="configuration file to use")
parser.add_argument(
"-o", "--output_path", type=str, default='./output/infer/')
return parser.parse_args()
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
@ -52,23 +60,39 @@ def main():
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
)
infer_shape[-1] = 100
model = to_static(
model,
input_spec=[
if config['Architecture']['algorithm'] == "SRN":
other_shape = [
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
shape=[None, 1, 64, 256], dtype='float32'), [
paddle.static.InputSpec(
shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 25, 1],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64"),
paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64")
]
]
model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
)
infer_shape[-1] = 100
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))

@ -25,6 +25,7 @@ import numpy as np
import math
import time
import traceback
import paddle
import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
@ -46,6 +47,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'rec', logger)
@ -70,6 +78,78 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@ -93,21 +173,64 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
if self.rec_algorithm != "SRN":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
if self.rec_algorithm == "SRN":
starttime = time.time()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch,
encoder_word_pos_list,
gsrm_word_pos_list,
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {"predict": outputs[2]}
else:
starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]

@ -62,7 +62,13 @@ def main():
elif op_name in ['RecResizeImg']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image']
if config['Architecture']['algorithm'] == "SRN":
op[op_name]['keep_keys'] = [
'image', 'encoder_word_pos', 'gsrm_word_pos',
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
@ -74,10 +80,25 @@ def main():
img = f.read()
data = {'image': img}
batch = transform(data, ops)
if config['Architecture']['algorithm'] == "SRN":
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
others = [
paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_word_pos_list),
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
else:
preds = model(images)
post_result = post_process_class(preds)
for rec_reuslt in post_result:
logger.info('\t result: {}'.format(rec_reuslt))

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save