|
|
|
@ -214,6 +214,8 @@ class SimpleReader(object):
|
|
|
|
|
self.mode = params['mode']
|
|
|
|
|
self.infer_img = params['infer_img']
|
|
|
|
|
self.use_tps = False
|
|
|
|
|
if "num_heads" in params:
|
|
|
|
|
self.num_heads = params['num_heads']
|
|
|
|
|
if "tps" in params:
|
|
|
|
|
self.use_tps = True
|
|
|
|
|
self.use_distort = False
|
|
|
|
@ -251,12 +253,19 @@ class SimpleReader(object):
|
|
|
|
|
img = cv2.imread(single_img)
|
|
|
|
|
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
|
|
|
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
|
|
|
|
norm_img = process_image(
|
|
|
|
|
img=img,
|
|
|
|
|
image_shape=self.image_shape,
|
|
|
|
|
char_ops=self.char_ops,
|
|
|
|
|
tps=self.use_tps,
|
|
|
|
|
infer_mode=True)
|
|
|
|
|
if self.loss_type == 'srn':
|
|
|
|
|
norm_img = process_image_srn(
|
|
|
|
|
img=img,
|
|
|
|
|
image_shape=self.image_shape,
|
|
|
|
|
num_heads=self.num_heads,
|
|
|
|
|
max_text_length=self.max_text_length)
|
|
|
|
|
else:
|
|
|
|
|
norm_img = process_image(
|
|
|
|
|
img=img,
|
|
|
|
|
image_shape=self.image_shape,
|
|
|
|
|
char_ops=self.char_ops,
|
|
|
|
|
tps=self.use_tps,
|
|
|
|
|
infer_mode=True)
|
|
|
|
|
yield norm_img
|
|
|
|
|
else:
|
|
|
|
|
with open(self.label_file_path, "rb") as fin:
|
|
|
|
@ -286,14 +295,25 @@ class SimpleReader(object):
|
|
|
|
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
|
|
|
|
|
|
|
|
|
label = substr[1]
|
|
|
|
|
outs = process_image(
|
|
|
|
|
img=img,
|
|
|
|
|
image_shape=self.image_shape,
|
|
|
|
|
label=label,
|
|
|
|
|
char_ops=self.char_ops,
|
|
|
|
|
loss_type=self.loss_type,
|
|
|
|
|
max_text_length=self.max_text_length,
|
|
|
|
|
distort=self.use_distort)
|
|
|
|
|
if self.loss_type == "srn":
|
|
|
|
|
outs = process_image_srn(
|
|
|
|
|
img=img,
|
|
|
|
|
image_shape=self.image_shape,
|
|
|
|
|
num_heads=self.num_heads,
|
|
|
|
|
max_text_length=self.max_text_length,
|
|
|
|
|
label=label,
|
|
|
|
|
char_ops=self.char_ops,
|
|
|
|
|
loss_type=self.loss_type)
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
outs = process_image(
|
|
|
|
|
img=img,
|
|
|
|
|
image_shape=self.image_shape,
|
|
|
|
|
label=label,
|
|
|
|
|
char_ops=self.char_ops,
|
|
|
|
|
loss_type=self.loss_type,
|
|
|
|
|
max_text_length=self.max_text_length,
|
|
|
|
|
distort=self.use_distort)
|
|
|
|
|
if outs is None:
|
|
|
|
|
continue
|
|
|
|
|
yield outs
|
|
|
|
|