Merge pull request #1597 from tink2123/dygraph_for_srn

【Do not merge】Add srn model
revert-1929-fix_typo
xiaoting 4 years ago committed by GitHub
commit acd479ea46
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,5 @@
Global: Global:
use_gpu: true use_gpu: True
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
@ -59,7 +59,7 @@ Metric:
Train: Train:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/ data_dir: ./train_data/data_lmdb_release/training/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
@ -78,7 +78,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/validation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image

@ -58,7 +58,7 @@ Metric:
Train: Train:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/ data_dir: ./train_data/data_lmdb_release/training/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
@ -77,7 +77,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/validation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image

@ -63,7 +63,7 @@ Metric:
Train: Train:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/ data_dir: ./train_data/data_lmdb_release/training/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
@ -82,7 +82,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/validation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image

@ -58,7 +58,7 @@ Metric:
Train: Train:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/ data_dir: ./train_data/data_lmdb_release/training/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
@ -77,7 +77,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/validation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image

@ -56,7 +56,7 @@ Metric:
Train: Train:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/ data_dir: ./train_data/data_lmdb_release/training/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
@ -75,7 +75,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/validation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image

@ -62,7 +62,7 @@ Metric:
Train: Train:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/ data_dir: ./train_data/data_lmdb_release/training/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
@ -81,7 +81,7 @@ Train:
Eval: Eval:
dataset: dataset:
name: LMDBDateSet name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/ data_dir: ./train_data/data_lmdb_release/validation/
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image

@ -0,0 +1,107 @@
Global:
use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/rec/srn_new
save_epoch_step: 3
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 5000]
# if pretrained_model is saved in static mode, load_static_weights must set to True
cal_metric_during_train: True
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path:
character_type: en
max_text_length: 25
num_heads: 8
infer_mode: False
use_space_char: False
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
clip_norm: 10.0
lr:
learning_rate: 0.0001
Architecture:
model_type: rec
algorithm: SRN
in_channels: 1
Transform:
Backbone:
name: ResNetFPN
Head:
name: SRNHead
max_text_length: 25
num_heads: 8
num_encoder_TUs: 2
num_decoder_TUs: 4
hidden_dims: 512
Loss:
name: SRNLoss
PostProcess:
name: SRNLabelDecode
Metric:
name: RecMetric
main_indicator: acc
Train:
dataset:
name: LMDBDataSet
data_dir: ./train_data/srn_train_data_duiqi
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SRNLabelEncode: # Class handling label
- SRNRecResizeImg:
image_shape: [1, 64, 256]
- KeepKeys:
keep_keys: ['image',
'label',
'length',
'encoder_word_pos',
'gsrm_word_pos',
'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2'] # dataloader will return list in this order
loader:
shuffle: False
batch_size_per_card: 64
drop_last: False
num_workers: 4
Eval:
dataset:
name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/evaluation
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- SRNLabelEncode: # Class handling label
- SRNRecResizeImg:
image_shape: [1, 64, 256]
- KeepKeys:
keep_keys: ['image',
'label',
'length',
'encoder_word_pos',
'gsrm_word_pos',
'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 32
num_workers: 4

@ -33,7 +33,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDateSet from ppocr.data.lmdb_dataset import LMDBDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators'] __all__ = ['build_dataloader', 'transform', 'create_operators']
@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp)
def build_dataloader(config, mode, device, logger, seed=None): def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config) config = copy.deepcopy(config)
support_dict = ['SimpleDataSet', 'LMDBDateSet'] support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name'] module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict)) 'DataSet only support {}'.format(support_dict))

@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop from .random_crop_data import EastRandomCropData, PSERandomCrop
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment from .randaugment import RandAugment
from .operators import * from .operators import *
from .label_ops import * from .label_ops import *

@ -102,6 +102,8 @@ class BaseRecLabelEncode(object):
support_character_type, character_type) support_character_type, character_type)
self.max_text_len = max_text_length self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en": if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
@ -217,3 +219,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end % beg_or_end
return idx return idx
class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length=25,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def __call__(self, data):
text = data['label']
text = self.encode(text)
char_num = len(self.character_str)
if text is None:
return None
if len(text) > self.max_text_len:
return None
data['length'] = np.array(len(text))
text = text + [char_num] * (self.max_text_len - len(text))
data['label'] = np.array(text)
return data
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx

@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math import math
import cv2 import cv2
import numpy as np import numpy as np
@ -77,6 +63,26 @@ class RecResizeImg(object):
return data return data
class SRNRecResizeImg(object):
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
self.image_shape = image_shape
self.num_heads = num_heads
self.max_text_length = max_text_length
def __call__(self, data):
img = data['image']
norm_img = resize_norm_img_srn(img, self.image_shape)
data['image'] = norm_img
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
data['encoder_word_pos'] = encoder_word_pos
data['gsrm_word_pos'] = gsrm_word_pos
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
return data
def resize_norm_img(img, image_shape): def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
h = img.shape[0] h = img.shape[0]
@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
def resize_norm_img_chinese(img, image_shape): def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape # todo: change to 0 and modified image shape
max_wh_ratio = 0 max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1] h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio) max_wh_ratio = max(max_wh_ratio, ratio)
@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
return padding_im return padding_im
def resize_norm_img_srn(img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
[num_heads, 1, 1]) * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
[num_heads, 1, 1]) * [-1e9]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def flag(): def flag():
""" """
flag flag

@ -23,11 +23,14 @@ def build_loss(config):
# rec loss # rec loss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .rec_srn_loss import SRNLoss
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss'] support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss'
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')

@ -0,0 +1,47 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
class SRNLoss(nn.Layer):
def __init__(self, **kwargs):
super(SRNLoss, self).__init__()
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
def forward(self, predicts, batch):
predict = predicts['predict']
word_predict = predicts['word_out']
gsrm_predict = predicts['gsrm_out']
label = batch[1]
casted_label = paddle.cast(x=label, dtype='int64')
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
cost_word = self.loss_func(word_predict, label=casted_label)
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
cost_vsfd = self.loss_func(predict, label=casted_label)
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}

@ -33,8 +33,6 @@ class RecMetric(object):
if pred == target: if pred == target:
correct_num += 1 correct_num += 1
all_num += 1 all_num += 1
# if all_num < 10 and kwargs.get('show_str', False):
# print('{} -> {}'.format(pred, target))
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
self.norm_edit_dis += norm_edit_dis self.norm_edit_dis += norm_edit_dis
@ -50,7 +48,7 @@ class RecMetric(object):
'norm_edit_dis': 0, 'norm_edit_dis': 0,
} }
""" """
acc = self.correct_num / self.all_num acc = 1.0 * self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
self.reset() self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis} return {'acc': acc, 'norm_edit_dis': norm_edit_dis}

@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"]) self.head = build_head(config["Head"])
def forward(self, x): def forward(self, x, data=None):
if self.use_transform: if self.use_transform:
x = self.transform(x) x = self.transform(x)
x = self.backbone(x) x = self.backbone(x)
if self.use_neck: if self.use_neck:
x = self.neck(x) x = self.neck(x)
if data is None:
x = self.head(x) x = self.head(x)
else:
x = self.head(x, data)
return x return x

@ -24,7 +24,8 @@ def build_backbone(config, model_type):
elif model_type == 'rec' or model_type == 'cls': elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet from .rec_resnet_vd import ResNet
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN'] from .rec_resnet_fpn import ResNetFPN
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
else: else:
raise NotImplementedError raise NotImplementedError

File diff suppressed because it is too large Load Diff

@ -23,10 +23,13 @@ def build_head(config):
# rec head # rec head
from .rec_ctc_head import CTCHead from .rec_ctc_head import CTCHead
from .rec_srn_head import SRNHead
# cls head # cls head
from .cls_head import ClsHead from .cls_head import ClsHead
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead'] support_dict = [
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead'
]
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format( assert module_name in support_dict, Exception('head only support {}'.format(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -26,11 +26,12 @@ def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess' 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)

@ -33,6 +33,9 @@ class BaseRecLabelDecode(object):
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type) support_character_type, character_type)
self.beg_str = "sos"
self.end_str = "eos"
if character_type == "en": if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
@ -109,7 +112,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
preds_idx = preds.argmax(axis=2) preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2) preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
@ -158,3 +160,84 @@ class AttnLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \ assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end % beg_or_end
return idx return idx
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
def __init__(self,
character_dict_path=None,
character_type='en',
use_space_char=False,
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
pred = preds['predict']
char_num = len(self.character_str) + 2
if isinstance(pred, paddle.Tensor):
pred = pred.numpy()
pred = np.reshape(pred, [-1, char_num])
preds_idx = np.argmax(pred, axis=1)
preds_prob = np.max(pred, axis=1)
preds_idx = np.reshape(preds_idx, [-1, 25])
preds_prob = np.reshape(preds_prob, [-1, 25])
text = self.decode(preds_idx, preds_prob)
if label is None:
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
return text
label = self.decode(label)
return text, label
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] in ignored_tokens:
continue
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text, np.mean(conf_list)))
return result_list
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
return dict_character
def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx

@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser from tools.program import load_config, merge_config, ArgsParser
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-c", "--config", help="configuration file to use")
parser.add_argument(
"-o", "--output_path", type=str, default='./output/infer/')
return parser.parse_args()
def main(): def main():
FLAGS = ArgsParser().parse_args() FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config) config = load_config(FLAGS.config)
@ -52,6 +60,22 @@ def main():
save_path = '{}/inference'.format(config['Global']['save_inference_dir']) save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
if config['Architecture']['algorithm'] == "SRN":
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 64, 256], dtype='float32'), [
paddle.static.InputSpec(
shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 25, 1],
dtype="int64"), paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64"),
paddle.static.InputSpec(
shape=[None, 8, 25, 25], dtype="int64")
]
]
model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1] infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec": if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32 infer_shape = [3, 32, -1] # for rec model, H must be 32
@ -62,13 +86,13 @@ def main():
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training' 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
) )
infer_shape[-1] = 100 infer_shape[-1] = 100
model = to_static( model = to_static(
model, model,
input_spec=[ input_spec=[
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32') shape=[None] + infer_shape, dtype='float32')
]) ])
paddle.jit.save(model, save_path) paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path)) logger.info('inference model is saved to {}'.format(save_path))

@ -25,6 +25,7 @@ import numpy as np
import math import math
import time import time
import traceback import traceback
import paddle
import tools.infer.utility as utility import tools.infer.utility as utility
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
@ -46,6 +47,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path, "character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char "use_space_char": args.use_space_char
} }
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \ self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
@ -70,6 +78,78 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image padding_im[:, :, 0:resized_w] = resized_image
return padding_im return padding_im
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
encoder_word_pos = encoder_word_pos.astype(np.int64)
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def __call__(self, img_list): def __call__(self, img_list):
img_num = len(img_list) img_num = len(img_list)
# Calculate the aspect ratio of all text bars # Calculate the aspect ratio of all text bars
@ -93,21 +173,64 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio) max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no): for ino in range(beg_img_no, end_img_no):
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) if self.rec_algorithm != "SRN":
norm_img = self.resize_norm_img(img_list[indices[ino]], norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio) max_wh_ratio)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img) norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
if self.rec_algorithm == "SRN":
starttime = time.time()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch,
encoder_word_pos_list,
gsrm_word_pos_list,
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[
i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {"predict": outputs[2]}
else:
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run() self.predictor.run()
outputs = [] outputs = []
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
preds = outputs[0] preds = outputs[0]
rec_result = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]

@ -62,6 +62,12 @@ def main():
elif op_name in ['RecResizeImg']: elif op_name in ['RecResizeImg']:
op[op_name]['infer_mode'] = True op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys': elif op_name == 'KeepKeys':
if config['Architecture']['algorithm'] == "SRN":
op[op_name]['keep_keys'] = [
'image', 'encoder_word_pos', 'gsrm_word_pos',
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
else:
op[op_name]['keep_keys'] = ['image'] op[op_name]['keep_keys'] = ['image']
transforms.append(op) transforms.append(op)
global_config['infer_mode'] = True global_config['infer_mode'] = True
@ -74,9 +80,24 @@ def main():
img = f.read() img = f.read()
data = {'image': img} data = {'image': img}
batch = transform(data, ops) batch = transform(data, ops)
if config['Architecture']['algorithm'] == "SRN":
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
others = [
paddle.to_tensor(encoder_word_pos_list),
paddle.to_tensor(gsrm_word_pos_list),
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
images = np.expand_dims(batch[0], axis=0) images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
else:
preds = model(images) preds = model(images)
post_result = post_process_class(preds) post_result = post_process_class(preds)
for rec_reuslt in post_result: for rec_reuslt in post_result:

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save