|
|
|
@ -41,19 +41,19 @@ def draw_server_result(image_file, res):
|
|
|
|
|
if len(res) == 0:
|
|
|
|
|
return np.array(image)
|
|
|
|
|
keys = res[0].keys()
|
|
|
|
|
if 'text_region' not in keys: # for ocr_rec, draw function is invalid
|
|
|
|
|
print("draw function is invalid for ocr_rec!")
|
|
|
|
|
if 'text_region' not in keys: # for ocr_rec, draw function is invalid
|
|
|
|
|
logger.info("draw function is invalid for ocr_rec!")
|
|
|
|
|
return None
|
|
|
|
|
elif 'text' not in keys: # for ocr_det
|
|
|
|
|
print("draw text boxes only!")
|
|
|
|
|
elif 'text' not in keys: # for ocr_det
|
|
|
|
|
logger.info("draw text boxes only!")
|
|
|
|
|
boxes = []
|
|
|
|
|
for dno in range(len(res)):
|
|
|
|
|
boxes.append(res[dno]['text_region'])
|
|
|
|
|
boxes = np.array(boxes)
|
|
|
|
|
draw_img = draw_boxes(image, boxes)
|
|
|
|
|
return draw_img
|
|
|
|
|
else: # for ocr_system
|
|
|
|
|
print("draw boxes and texts!")
|
|
|
|
|
else: # for ocr_system
|
|
|
|
|
logger.info("draw boxes and texts!")
|
|
|
|
|
boxes = []
|
|
|
|
|
texts = []
|
|
|
|
|
scores = []
|
|
|
|
@ -63,7 +63,8 @@ def draw_server_result(image_file, res):
|
|
|
|
|
scores.append(res[dno]['confidence'])
|
|
|
|
|
boxes = np.array(boxes)
|
|
|
|
|
scores = np.array(scores)
|
|
|
|
|
draw_img = draw_ocr(image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
|
|
|
|
|
draw_img = draw_ocr(
|
|
|
|
|
image, boxes, texts, scores, draw_txt=True, drop_score=0.5)
|
|
|
|
|
return draw_img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -81,13 +82,13 @@ def main(url, image_path):
|
|
|
|
|
|
|
|
|
|
# 发送HTTP请求
|
|
|
|
|
starttime = time.time()
|
|
|
|
|
data = {'images':[cv2_to_base64(img)]}
|
|
|
|
|
data = {'images': [cv2_to_base64(img)]}
|
|
|
|
|
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
|
|
|
|
elapse = time.time() - starttime
|
|
|
|
|
total_time += elapse
|
|
|
|
|
print("Predict time of %s: %.3fs" % (image_file, elapse))
|
|
|
|
|
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
|
|
|
|
|
res = r.json()["results"][0]
|
|
|
|
|
print(res)
|
|
|
|
|
logger.info(res)
|
|
|
|
|
|
|
|
|
|
if is_visualize:
|
|
|
|
|
draw_img = draw_server_result(image_file, res)
|
|
|
|
@ -98,16 +99,17 @@ def main(url, image_path):
|
|
|
|
|
cv2.imwrite(
|
|
|
|
|
os.path.join(draw_img_save, os.path.basename(image_file)),
|
|
|
|
|
draw_img[:, :, ::-1])
|
|
|
|
|
print("The visualized image saved in {}".format(
|
|
|
|
|
logger.info("The visualized image saved in {}".format(
|
|
|
|
|
os.path.join(draw_img_save, os.path.basename(image_file))))
|
|
|
|
|
cnt += 1
|
|
|
|
|
if cnt % 100 == 0:
|
|
|
|
|
print(cnt, "processed")
|
|
|
|
|
print("avg time cost: ", float(total_time)/cnt)
|
|
|
|
|
logger.info("{} processed".format(cnt))
|
|
|
|
|
logger.info("avg time cost: {}".format(float(total_time) / cnt))
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
if len(sys.argv) != 3:
|
|
|
|
|
print("Usage: %s server_url image_path" % sys.argv[0])
|
|
|
|
|
logger.info("Usage: %s server_url image_path" % sys.argv[0])
|
|
|
|
|
else:
|
|
|
|
|
server_url = sys.argv[1]
|
|
|
|
|
image_path = sys.argv[2]
|
|
|
|
|