You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/model_zoo/official/cv/crnn/convert_ic03.py

135 lines
4.7 KiB

4 years ago
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np
def init_args():
    """
    Parse the command-line options of the conversion script.

    Returns:
        argparse.Namespace with the attributes ``dataset_dir``, ``xml_file``,
        ``output_dir`` and ``output_annotation``.
    """
    # Pass the text as ``description``: the first positional argument of
    # ArgumentParser is ``prog``, and an empty prog blanks the usage line.
    parser = argparse.ArgumentParser(
        description='Crop the annotated rectangles of a dataset into separate images.')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='Directory that contains the original images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='Path to the XML annotation file of the dataset')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='Directory where the cropped images are saved')
    parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
                        help='Path of the annotation text file to write')
    return parser.parse_args()
def xml_to_dict(xml_file, save_file=False):
    """
    Parse the annotation XML into a list with one dictionary per image entry.

    Every child of the XML root is treated as one image entry. A
    ``taggedRectangles`` child is expanded into a list stored under the
    ``'rect'`` key; each item keeps the rectangle's XML attributes under
    ``'location'`` and the text of its first child element under ``'label'``.
    Every other child is stored as ``tag -> text``.

    Args:
        xml_file: source accepted by ``ElementTree.parse`` (path or file object).
        save_file: when true, additionally dump the result to
            ``annotation_train.npy`` in the current working directory.

    Returns:
        list of dict, one per image entry, in document order. An entry only
        has a ``'rect'`` key if it contains a ``taggedRectangles`` element.
    """
    root = ET.parse(xml_file).getroot()
    imgs_labels = []
    for image_node in root:
        im_label = {}
        for field in image_node:
            # Exact match on purpose: a substring test would also accept
            # unrelated tags such as 'tag' and clobber the 'rect' entry.
            if field.tag == 'taggedRectangles':
                im_label['rect'] = [
                    {'location': rect_node.attrib, 'label': rect_node[0].text}
                    for rect_node in field
                ]
            else:
                im_label[field.tag] = field.text
        imgs_labels.append(im_label)
    if save_file:
        # A non-empty list of dicts is stored as a pickled object array, so
        # reading it back requires np.load(..., allow_pickle=True).
        np.save("annotation_train.npy", imgs_labels)
    return imgs_labels
def image_crop_save(image, location, output_dir):
    """
    Cut a rectangle out of ``image`` and write it to ``output_dir``.

    Args:
        image: array exposing ``.shape`` and 2-D slicing, indexed as
            [row, column(, channel)] (presumably a numpy array; ``convert``
            passes the result of ``np.asarray`` here).
        location: sequence ordered as (height, width, x, y).
        output_dir: target handed to ``Image.save``; despite its name this
            is the path of the output file, not a directory.

    Negative x / y values are raised to 0; the window keeps the annotated
    width and height, measured from the clamped origin.
    """
    # annotations may carry negative coordinates, clamp the origin to 0
    left = np.maximum(location[2], 0)
    top = np.maximum(location[3], 0)
    right = left + location[1]
    bottom = top + location[0]
    print("image array shape :{}".format(image.shape))
    print("crop region ", left, right, top, bottom)
    # Trailing axes are left unindexed, so all channels of a colour array are
    # kept and the same slice also serves single-channel arrays.
    window = image[top:bottom, left:right]
    Image.fromarray(np.uint8(window)).save(output_dir)
def convert():
    """
    Crop every annotated rectangle into its own image file and write the
    matching annotation list.

    Steps:
        1. Read the options with ``init_args`` and validate the input paths.
        2. Parse the XML with ``xml_to_dict`` (called with ``save_file=True``,
           so ``annotation_train.npy`` is written as well).
        3. For each image, crop each rectangle with ``image_crop_save`` into
           ``output_dir`` as ``<name>_<index>_<label><ext>``; '/' in the image
           name becomes '_' and commas are removed from the file name.
        4. Write one ``<file name>,<label>,`` line per crop to the
           annotation file.

    Raises:
        ValueError: if the dataset directory or the XML file does not exist.
    """
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))
    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))
    os.makedirs(args.output_dir, exist_ok=True)

    ims_labels_dict = xml_to_dict(args.xml_file, True)
    print("Converting annotation, {} images in total ".format(len(ims_labels_dict)))

    annotation_list = []
    for img_label in ims_labels_dict:
        image_name = img_label['imageName']
        # An entry without a taggedRectangles element carries no 'rect' key;
        # treat it as an image with nothing to crop instead of failing.
        rects = img_label.get('rect', [])
        # splitext keeps the leading dot in ``ext`` and copes with names
        # that have no extension at all.
        name, ext = os.path.splitext(image_name)
        fullpath = os.path.join(args.dataset_dir, image_name)
        # Read the pixels inside the context manager so the file handle is
        # released right away instead of lingering until garbage collection.
        with Image.open(fullpath) as img:
            im_array = np.asarray(img)
        print("processing image: {}".format(image_name))
        for j, rect in enumerate(rects):
            location = rect['location']
            # Attribute values are strings and may be written as floats.
            h = int(float(location['height']))
            w = int(float(location['width']))
            x = int(float(location['x']))
            y = int(float(location['y']))
            label = rect['label']
            output_name = name.replace("/", "_") + "_" + str(j) + "_" + label + ext
            # Commas separate the fields of the annotation file, so they
            # must not appear inside the file name.
            output_name = output_name.replace(",", "")
            output_file = os.path.join(args.output_dir, output_name)
            image_crop_save(im_array, [h, w, x, y], output_file)
            annotation_list.append(output_name + "," + label + ',')

    with open(args.output_annotation, 'w') as f:
        f.writelines(line + '\n' for line in annotation_list)
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    convert()