You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
252 lines
9.6 KiB
252 lines
9.6 KiB
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""YOLOV4 dataset."""
|
|
import os
|
|
import multiprocessing
|
|
from PIL import Image
|
|
import cv2
|
|
from pycocotools.coco import COCO
|
|
|
|
import mindspore.dataset as de
|
|
import mindspore.dataset.vision.c_transforms as CV
|
|
|
|
from src.distributed_sampler import DistributedSampler
|
|
from src.transforms import reshape_fn, MultiScaleTrans
|
|
|
|
|
|
# Minimum number of visible keypoints an image needs before has_valid_annotation
# accepts it (only consulted when the annotations carry a "keypoints" field).
min_keypoints_per_image = 10
|
|
|
|
|
|
def _has_only_empty_bbox(anno):
|
|
return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
|
|
|
|
|
|
def _count_visible_keypoints(anno):
|
|
return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
|
|
|
|
|
|
def has_valid_annotation(anno):
    """Check whether the annotations of one image are usable.

    Args:
        anno (list[dict]): Annotations of a single image.

    Returns:
        bool: False when ``anno`` is empty or when every box is degenerate.
        Otherwise True for annotations without a "keypoints" field; for
        keypoint annotations, True only if at least
        ``min_keypoints_per_image`` keypoints are visible.
    """
    # if it's empty, there is no annotation
    if not anno:
        return False
    # if all boxes have close to zero area, there is no annotation
    if _has_only_empty_bbox(anno):
        return False
    # keypoint tasks have a slightly different criterion for considering
    # an annotation valid; only the first annotation is inspected to decide
    # which kind of task this is
    if "keypoints" not in anno[0]:
        return True
    # for keypoint detection tasks, only consider valid those images
    # containing at least min_keypoints_per_image visible keypoints
    return _count_visible_keypoints(anno) >= min_keypoints_per_image
|
|
|
|
|
|
class COCOYoloDataset:
    """YOLOV4 Dataset for COCO.

    Wraps a COCO annotation file and an image directory so that samples can be
    fetched by integer index.

    Args:
        root (str): Directory that the annotation "file_name" entries are
            joined onto.
        ann_file (str): Path passed to ``COCO`` to load the annotations.
        remove_images_without_annotations (bool): When True, keep only the
            image ids for which ``has_valid_annotation`` returns True.
        filter_crowd_anno (bool): When True, ``__getitem__`` drops annotations
            whose "iscrowd" field is not 0.
        is_training (bool): Selects the sample layout returned by
            ``__getitem__``.
    """

    def __init__(self, root, ann_file, remove_images_without_annotations=True,
                 filter_crowd_anno=True, is_training=True):
        self.coco = COCO(ann_file)
        self.root = root
        self.img_ids = sorted(self.coco.imgs.keys())
        self.filter_crowd_anno = filter_crowd_anno
        self.is_training = is_training

        # filter images without any annotations
        if remove_images_without_annotations:
            img_ids = []
            for img_id in self.img_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id, iscrowd=None)
                anno = self.coco.loadAnns(ann_ids)
                if has_valid_annotation(anno):
                    img_ids.append(img_id)
            self.img_ids = img_ids

        self.categories = {cat["id"]: cat["name"] for cat in self.coco.cats.values()}

        # COCO category ids are mapped to consecutive indices (and back) in the
        # order returned by getCatIds().
        self.cat_ids_to_continuous_ids = {
            v: i for i, v in enumerate(self.coco.getCatIds())
        }
        self.continuous_ids_cat_ids = {
            v: k for k, v in self.cat_ids_to_continuous_ids.items()
        }

    def __getitem__(self, index):
        """
        Args:
            index (int): Position in ``self.img_ids``.

        Returns:
            tuple: When ``is_training`` is False, ``(img, img_id)``.
            Otherwise ``(img, out_target, [], [], [], [], [], [])`` where
            ``out_target`` is a list of
            ``[x_min, y_min, x_max, y_max, label]`` rows and ``label`` is the
            continuous category index. The six empty lists are placeholders
            that are returned as-is. ``img`` is an RGB PIL image.
        """
        coco = self.coco
        img_id = self.img_ids[index]
        img_path = coco.loadImgs(img_id)[0]["file_name"]
        # convert() returns an independent image, so the file handle opened by
        # Image.open can be released immediately.
        with Image.open(os.path.join(self.root, img_path)) as raw_img:
            img = raw_img.convert("RGB")
        if not self.is_training:
            return img, img_id

        annos = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        # filter crowd annotations
        if self.filter_crowd_anno:
            annos = [anno for anno in annos if anno["iscrowd"] == 0]

        out_target = []
        for anno in annos:
            label = self.cat_ids_to_continuous_ids[anno["category_id"]]
            # row layout: [x_min, y_min, x_max, y_max, label]
            out_target.append(self._conve_top_down(anno["bbox"]) + [int(label)])
        return img, out_target, [], [], [], [], [], []

    def __len__(self):
        """Return the number of image ids kept after optional filtering."""
        return len(self.img_ids)

    def _conve_top_down(self, bbox):
        """Convert ``[x_min, y_min, w, h]`` to ``[x_min, y_min, x_max, y_max]``."""
        x_min, y_min, w, h = bbox[0], bbox[1], bbox[2], bbox[3]
        return [x_min, y_min, x_min + w, y_min + h]
|
|
|
|
|
|
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
                        config=None, is_training=True, shuffle=True):
    """Create dataset for YOLOV4.

    Args:
        image_dir (str): Image directory handed to ``COCOYoloDataset`` as root.
        anno_path (str): Annotation file handed to ``COCOYoloDataset``.
        batch_size (int): Batch size; incomplete batches are dropped.
        max_epoch (int): Number of times the dataset is repeated.
        device_num (int): Number of devices, used for sharding and to split
            the CPU cores between devices.
        rank (int): Rank of this device, passed to the sampler.
        config: Configuration object. Required despite the ``None`` default:
            ``dataset_size`` is written onto it and it is forwarded to the
            transforms.
        is_training (bool): Build the training pipeline (multi-scale batch
            transform) or the evaluation pipeline (reshape + HWC2CHW).
        shuffle (bool): Passed to the distributed sampler.

    Returns:
        tuple: ``(ds, dataset_size)`` with the MindSpore dataset and the
        number of samples in the underlying ``COCOYoloDataset``.

    Raises:
        ValueError: If ``config`` is None.
    """
    # Fail fast: config is dereferenced below, and loading the annotations
    # first only to crash afterwards wastes time and hides the cause.
    if config is None:
        raise ValueError("create_yolo_dataset requires a config object, but config is None.")

    cv2.setNumThreads(0)

    # Training drops crowd annotations and images without usable boxes;
    # evaluation keeps every image.
    filter_crowd = remove_empty_anno = bool(is_training)

    yolo_dataset = COCOYoloDataset(root=image_dir, ann_file=anno_path, filter_crowd_anno=filter_crowd,
                                   remove_images_without_annotations=remove_empty_anno, is_training=is_training)
    distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
    hwc_to_chw = CV.HWC2CHW()

    config.dataset_size = len(yolo_dataset)
    cores = multiprocessing.cpu_count()
    # Never go below one worker: with fewer cores than devices the plain
    # division would yield 0, which is not a usable worker count.
    num_parallel_workers = max(1, cores // device_num)
    if is_training:
        multi_scale_trans = MultiScaleTrans(config, device_num)
        dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3",
                                "gt_box1", "gt_box2", "gt_box3"]
        if device_num != 8:
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names,
                                     num_parallel_workers=min(32, num_parallel_workers),
                                     sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                          num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True)
        else:
            # With exactly 8 devices the generator keeps its default worker
            # count and the batch stage is capped at 8 workers.
            ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler)
            ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names,
                          num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True)
    else:
        ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
                                 sampler=distributed_sampler)
        compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
        ds = ds.map(operations=compose_map_func, input_columns=["image", "img_id"],
                    output_columns=["image", "image_shape", "img_id"],
                    column_order=["image", "image_shape", "img_id"],
                    num_parallel_workers=8)
        ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8)
        ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(max_epoch)

    return ds, len(yolo_dataset)
|
|
|
|
|
|
|
|
class COCOYoloDatasetv2():
    """
    COCO yolo dataset definition driven by a plain-text image list.

    Args:
        root (str): Directory that the listed file names are joined onto.
        data_txt (str): Text file with one image path per line. Only the base
            name of each path is kept; lines that are blank (or that have an
            empty base name) are skipped.
    """
    def __init__(self, root, data_txt):
        self.root = root
        with open(data_txt, 'r') as f:
            names = (os.path.basename(line.strip()) for line in f)
            # An empty name can neither be opened as an image nor parsed as an
            # id, so it is dropped here instead of failing in __getitem__.
            self.img_path = [name for name in names if name]

    def __getitem__(self, index):
        """
        Args:
            index (int): Position in the image list.

        Returns:
            tuple: ``(img, img_id)`` where ``img`` is an RGB PIL image and
            ``img_id`` is the file name with '.jpg' removed, converted to int.

        Raises:
            ValueError: If the file name without '.jpg' is not an integer.
        """
        file_name = self.img_path[index]
        img_id = file_name.replace('.jpg', '')
        # convert() returns an independent image, so the file handle opened by
        # Image.open can be released immediately.
        with Image.open(os.path.join(self.root, file_name)) as raw_img:
            img = raw_img.convert("RGB")
        return img, int(img_id)

    def __len__(self):
        """Return the number of images in the list."""
        return len(self.img_path)
|
|
|
|
|
|
|
|
def create_yolo_datasetv2(image_dir,
                          data_txt,
                          batch_size,
                          max_epoch,
                          device_num,
                          rank,
                          config=None,
                          shuffle=True):
    """
    Create yolo dataset from a plain-text image list (evaluation-style pipeline).

    Args:
        image_dir (str): Image directory handed to ``COCOYoloDatasetv2`` as root.
        data_txt (str): Text file listing the images, one path per line.
        batch_size (int): Batch size; incomplete batches are dropped.
        max_epoch (int): Number of times the dataset is repeated.
        device_num (int): Number of devices, passed to the sampler.
        rank (int): Rank of this device, passed to the sampler.
        config: Configuration object. Required despite the ``None`` default:
            ``dataset_size`` is written onto it and it is forwarded to
            ``reshape_fn``.
        shuffle (bool): Passed to the distributed sampler.

    Returns:
        tuple: ``(ds, dataset_size)`` with the MindSpore dataset and the
        number of listed images.

    Raises:
        ValueError: If ``config`` is None.
    """
    # Fail fast: config is dereferenced below, so report the real cause
    # before any file is opened.
    if config is None:
        raise ValueError("create_yolo_datasetv2 requires a config object, but config is None.")

    yolo_dataset = COCOYoloDatasetv2(root=image_dir, data_txt=data_txt)
    distributed_sampler = DistributedSampler(len(yolo_dataset), device_num, rank, shuffle=shuffle)
    hwc_to_chw = CV.HWC2CHW()

    config.dataset_size = len(yolo_dataset)

    ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"],
                             sampler=distributed_sampler)
    compose_map_func = (lambda image, img_id: reshape_fn(image, img_id, config))
    ds = ds.map(input_columns=["image", "img_id"],
                output_columns=["image", "image_shape", "img_id"],
                column_order=["image", "image_shape", "img_id"],
                operations=compose_map_func, num_parallel_workers=8)
    ds = ds.map(input_columns=["image"], operations=hwc_to_chw, num_parallel_workers=8)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.repeat(max_epoch)
    return ds, len(yolo_dataset)
|