|
|
|
@ -15,19 +15,17 @@
|
|
|
|
|
import utility
|
|
|
|
|
from ppocr.utils.utility import initial_logger
|
|
|
|
|
logger = initial_logger()
|
|
|
|
|
from ppocr.utils.utility import get_image_file_list
|
|
|
|
|
import cv2
|
|
|
|
|
from ppocr.data.det.east_process import EASTProcessTest
|
|
|
|
|
from ppocr.data.det.db_process import DBProcessTest
|
|
|
|
|
from ppocr.postprocess.db_postprocess import DBPostProcess
|
|
|
|
|
from ppocr.postprocess.east_postprocess import EASTPostPocess
|
|
|
|
|
from ppocr.utils.utility import get_image_file_list
|
|
|
|
|
from tools.infer.utility import draw_ocr
|
|
|
|
|
import copy
|
|
|
|
|
import numpy as np
|
|
|
|
|
import math
|
|
|
|
|
import time
|
|
|
|
|
import sys
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextDetector(object):
|
|
|
|
@ -79,27 +77,10 @@ class TextDetector(object):
|
|
|
|
|
rect = np.array([tl, tr, br, bl], dtype="float32")
|
|
|
|
|
return rect
|
|
|
|
|
|
|
|
|
|
def expand_det_res(self, points, bbox_height, bbox_width, img_height,
                   img_width):
    """Grow a four-point box outward and clamp it to the image extent.

    The margin is derived from the box's aspect ratio:

    * height at least 2x the width  -> 20% of the width
    * width at least 3x the height  -> 20% of the height
    * anything else                 -> 10% of the height

    Rows 0 and 3 of ``points`` are moved left, rows 1 and 2 right,
    rows 0 and 1 up, and rows 3 and 2 down, i.e. the rows are treated as
    top-left, top-right, bottom-right, bottom-left.

    Args:
        points: box indexable as ``points[row, col]`` with rows 0..3 and
            columns (x, y). Modified in place.
            (presumably a 4x2 numpy array -- confirm against callers)
        bbox_height: height of the box's axis-aligned bounding rectangle.
        bbox_width: width of the box's axis-aligned bounding rectangle.
        img_height: upper bound applied to expanded y coordinates.
        img_width: upper bound applied to expanded x coordinates.

    Returns:
        The same ``points`` object, after expansion.
    """
    # The ratio tests are written as multiplications instead of
    # divisions so that a degenerate box (zero width or height) does not
    # raise ZeroDivisionError; for positive sizes the result is the same.
    if bbox_height >= 2.0 * bbox_width:
        margin = bbox_width * 0.20
    elif bbox_width >= 3.0 * bbox_height:
        margin = bbox_height * 0.20
    else:
        margin = bbox_height * 0.1
    expand_w = margin
    expand_h = margin

    # Horizontal expansion. Note the upper clamp is img_width itself,
    # not img_width - 1.
    points[0, 0] = int(max((points[0, 0] - expand_w), 0))
    points[1, 0] = int(min((points[1, 0] + expand_w), img_width))
    points[3, 0] = int(max((points[3, 0] - expand_w), 0))
    points[2, 0] = int(min((points[2, 0] + expand_w), img_width))

    # Vertical expansion, clamped to [0, img_height].
    points[0, 1] = int(max((points[0, 1] - expand_h), 0))
    points[1, 1] = int(max((points[1, 1] - expand_h), 0))
    points[3, 1] = int(min((points[3, 1] + expand_h), img_height))
    points[2, 1] = int(min((points[2, 1] + expand_h), img_height))

    # Hand the box back so callers can write
    # ``box = self.expand_det_res(...)``.
    return points
|
|
|
|
|
def clip_det_res(self, points, img_height, img_width):
    """Clamp every vertex of a box to valid pixel coordinates.

    Each x is limited to ``[0, img_width - 1]`` and each y to
    ``[0, img_height - 1]``; the clamped value is truncated with
    ``int()`` before being stored back.

    Args:
        points: box indexable as ``points[row, col]`` with columns
            (x, y). Modified in place. Any number of rows is accepted.
            (presumably an Nx2 numpy array -- confirm against callers)
        img_height: image height in pixels.
        img_width: image width in pixels.

    Returns:
        The same ``points`` object, after clamping.
    """
    max_x = img_width - 1
    max_y = img_height - 1
    # Walk all vertices rather than assuming exactly four, so polygons
    # with a different vertex count are fully clipped as well.
    for pno in range(len(points)):
        points[pno, 0] = int(min(max(points[pno, 0], 0), max_x))
        points[pno, 1] = int(min(max(points[pno, 1], 0), max_y))
    return points
|
|
|
|
|
|
|
|
|
|
def filter_tag_det_res(self, dt_boxes, image_shape):
|
|
|
|
@ -107,22 +88,11 @@ class TextDetector(object):
|
|
|
|
|
dt_boxes_new = []
|
|
|
|
|
for box in dt_boxes:
|
|
|
|
|
box = self.order_points_clockwise(box)
|
|
|
|
|
left = int(np.min(box[:, 0]))
|
|
|
|
|
right = int(np.max(box[:, 0]))
|
|
|
|
|
top = int(np.min(box[:, 1]))
|
|
|
|
|
bottom = int(np.max(box[:, 1]))
|
|
|
|
|
bbox_height = bottom - top
|
|
|
|
|
bbox_width = right - left
|
|
|
|
|
diffh = math.fabs(box[0, 1] - box[1, 1])
|
|
|
|
|
diffw = math.fabs(box[0, 0] - box[3, 0])
|
|
|
|
|
box = self.clip_det_res(box, img_height, img_width)
|
|
|
|
|
rect_width = int(np.linalg.norm(box[0] - box[1]))
|
|
|
|
|
rect_height = int(np.linalg.norm(box[0] - box[3]))
|
|
|
|
|
if rect_width <= 10 or rect_height <= 10:
|
|
|
|
|
continue
|
|
|
|
|
# if diffh <= 10 and diffw <= 10:
|
|
|
|
|
# box = self.expand_det_res(
|
|
|
|
|
# copy.deepcopy(box), bbox_height, bbox_width, img_height,
|
|
|
|
|
# img_width)
|
|
|
|
|
dt_boxes_new.append(box)
|
|
|
|
|
dt_boxes = np.array(dt_boxes_new)
|
|
|
|
|
return dt_boxes
|
|
|
|
@ -153,8 +123,6 @@ class TextDetector(object):
|
|
|
|
|
return dt_boxes, elapse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from tools.infer.utility import draw_text_det_res
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
args = utility.parse_args()
|
|
|
|
|
image_file_list = get_image_file_list(args.image_dir)
|
|
|
|
@ -171,9 +139,8 @@ if __name__ == "__main__":
|
|
|
|
|
total_time += elapse
|
|
|
|
|
count += 1
|
|
|
|
|
print("Predict time of %s:" % image_file, elapse)
|
|
|
|
|
img_draw = draw_text_det_res(dt_boxes, image_file, return_img=True)
|
|
|
|
|
save_path = os.path.join("./inference_det/",
|
|
|
|
|
os.path.basename(image_file))
|
|
|
|
|
print("The visualized image saved in {}".format(save_path))
|
|
|
|
|
|
|
|
|
|
print("Avg Time:", total_time / (count - 1))
|
|
|
|
|
src_im = utility.draw_text_det_res(dt_boxes, image_file)
|
|
|
|
|
img_name_pure = image_file.split("/")[-1]
|
|
|
|
|
cv2.imwrite("./inference_results/det_res_%s" % img_name_pure, src_im)
|
|
|
|
|
if count > 1:
|
|
|
|
|
print("Avg Time:", total_time / (count - 1))
|
|
|
|
|