|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
from collections import deque
|
|
|
|
|
import cv2
|
|
|
|
|
import numpy as np
|
|
|
|
|
from PIL import Image, ImageSequence
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
@ -23,7 +24,6 @@ from mindspore.dataset.vision.utils import Inter
|
|
|
|
|
from mindspore.communication.management import get_rank, get_group_size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_multipage_tiff(path):
    """Load tiff images containing many images in the channel dimension.

    Args:
        path: Location of the TIFF file, handed to ``PIL.Image.open``.

    Returns:
        numpy.ndarray obtained by converting every frame of the file to an
        array and stacking the frames along a new leading axis.
    """
    # The image handle is closed explicitly once every frame has been copied
    # into its own array, instead of lingering until garbage collection.
    with Image.open(path) as tiff:
        return np.array([np.array(page) for page in ImageSequence.Iterator(tiff)])
|
|
|
|
@ -164,3 +164,100 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
|
|
|
|
|
valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)
|
|
|
|
|
|
|
|
|
|
return train_ds, valid_ds
|
|
|
|
|
|
|
|
|
|
class CellNucleiDataset:
    """
    Cell nuclei dataset preprocess class.

    Every sub-directory of ``data_dir`` is one sample, identified by the
    directory name. The sorted ids are cut at ``int(len(ids) * split)``: the
    leading part (repeated ``repeat`` times and shuffled) is the training set,
    the remainder is the validation set. Indexing returns an ``(image, mask)``
    pair read from the cached ``image.png`` / ``mask.png`` of the sample.
    """

    def __init__(self, data_dir, repeat, is_train=False, split=0.8):
        """
        Scan ``data_dir``, build the train/val id lists and create the caches.

        Args:
            data_dir: Directory whose immediate sub-directories are the samples.
            repeat: Number of times the training ids are repeated.
            is_train: Serve the training ids when True, else the validation ids.
            split: Fraction of the sorted ids assigned to the training set.
        """
        self.data_dir = data_dir
        self.img_ids = sorted(next(os.walk(self.data_dir))[1])
        # Compute the boundary once so both slices are guaranteed to agree.
        split_at = int(len(self.img_ids) * split)
        self.train_ids = self.img_ids[:split_at] * repeat
        np.random.shuffle(self.train_ids)
        self.val_ids = self.img_ids[split_at:]
        self.is_train = is_train
        self._preprocess_dataset()

    def _preprocess_dataset(self):
        """
        Write ``image.png`` and ``mask.png`` for every sample that lacks one.

        The image comes from ``images/<id>.png``; the mask is the element-wise
        maximum over all files found in the sample's ``masks`` folder.

        Raises:
            ValueError: If the source image or a mask file cannot be decoded,
                or if the ``masks`` folder contains no files.
        """
        for img_id in self.img_ids:
            path = os.path.join(self.data_dir, img_id)
            image_file = os.path.join(path, "image.png")
            merged_mask_file = os.path.join(path, "mask.png")
            if os.path.exists(image_file) and os.path.exists(merged_mask_file):
                continue  # both caches already present

            source_file = os.path.join(path, "images", img_id + ".png")
            img = cv2.imread(source_file)
            # cv2.imread signals failure by returning None instead of raising.
            if img is None:
                raise ValueError(f"Cannot read image file: {source_file}")
            if len(img.shape) == 2:
                # Expand a single-channel image to three identical channels.
                img = np.expand_dims(img, axis=-1)
                img = np.concatenate([img, img, img], axis=-1)

            mask_dir = os.path.join(path, "masks")
            masks = []
            for mask_file in next(os.walk(mask_dir))[2]:
                mask_path = os.path.join(mask_dir, mask_file)
                mask_ = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask_ is None:
                    raise ValueError(f"Cannot read mask file: {mask_path}")
                masks.append(mask_)
            if not masks:
                raise ValueError(f"No mask files found in: {mask_dir}")
            mask = np.max(masks, axis=0)

            cv2.imwrite(image_file, img)
            cv2.imwrite(merged_mask_file, mask)

    def _read_img_mask(self, img_id):
        """Read the cached image and the cached grayscale mask of ``img_id``."""
        path = os.path.join(self.data_dir, img_id)
        img = cv2.imread(os.path.join(path, "image.png"))
        mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
        return img, mask

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair at ``index`` of the active id list."""
        if self.is_train:
            return self._read_img_mask(self.train_ids[index])
        return self._read_img_mask(self.val_ids[index])

    @property
    def column_names(self):
        """Names of the two columns produced by ``__getitem__``."""
        column_names = ['image', 'mask']
        return column_names

    def __len__(self):
        """Number of entries in the active (train or validation) id list."""
        if self.is_train:
            return len(self.train_ids)
        return len(self.val_ids)
|
|
|
|
|
|
|
|
|
|
def preprocess_img_mask(img, mask, img_size, augment=False):
    """
    Preprocess for cell nuclei dataset.
    Random crop and flip images and masks when augment is True.

    Args:
        img: Image array indexed as [row, column, channel].
        mask: Two-dimensional mask array matching ``img`` spatially.
        img_size: Output size as ``(width, height)``.
        augment: When True, resize to a random size of at least ``img_size``
            (upper bound ``int(img_size * 1.5)``, exclusive), take a random
            ``img_size`` crop and, with probability 0.5, flip with a random
            OpenCV flip code from {-1, 0, 1}. When False, resize straight to
            ``img_size``.

    Returns:
        Tuple ``(img, mask)``: ``img`` is float32, channel-first and mapped
        through ``(x - 127.5) / 127.5``; ``mask`` is a float32 array of shape
        ``(2, height, width)`` holding the one-hot encoding of
        ``mask / 255 > 0.5`` (channel 0: background, channel 1: foreground).
    """
    out_w, out_h = img_size[0], img_size[1]
    if augment:
        # randint without a size argument yields a scalar, avoiding the
        # deprecated conversion of a one-element array to int. The float
        # upper bound is truncated explicitly instead of implicitly.
        img_size_w = int(np.random.randint(out_w, int(out_w * 1.5)))
        img_size_h = int(np.random.randint(out_h, int(out_h * 1.5)))
        img = cv2.resize(img, (img_size_w, img_size_h))
        mask = cv2.resize(mask, (img_size_w, img_size_h))
        # Top-left corner of the crop; "+ 1" makes the last valid offset reachable.
        dw = int(np.random.randint(0, img_size_w - out_w + 1))
        dh = int(np.random.randint(0, img_size_h - out_h + 1))
        img = img[dh:dh + out_h, dw:dw + out_w, :]
        mask = mask[dh:dh + out_h, dw:dw + out_w]
        if np.random.random() > 0.5:
            flip_code = int(np.random.randint(-1, 2))
            img = cv2.flip(img, flip_code)
            mask = cv2.flip(mask, flip_code)
    else:
        img = cv2.resize(img, img_size)
        mask = cv2.resize(mask, img_size)
    img = (img.astype(np.float32) - 127.5) / 127.5
    img = img.transpose(2, 0, 1)
    mask = mask.astype(np.float32) / 255
    # np.int was removed in NumPy 1.24; the builtin int is the exact type it aliased.
    mask = (mask > 0.5).astype(int)
    # Compare against [0, 1] along a new trailing axis to get the one-hot encoding.
    mask = (np.arange(2) == mask[..., None]).astype(int)
    mask = mask.transpose(2, 0, 1).astype(np.float32)
    return img, mask
|
|
|
|
|
|
|
|
|
|
def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False,
                               split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8):
    """
    Get generator dataset for cell nuclei dataset.

    A ``CellNucleiDataset`` built from ``data_dir``, ``repeat``, ``is_train``
    and ``split`` is wrapped in a ``GeneratorDataset`` that uses a
    ``DistributedSampler(group_size, rank)`` shuffling only when ``is_train``.
    Every sample goes through ``preprocess_img_mask`` with ``tuple(img_size)``;
    augmentation is applied only when both ``augment`` and ``is_train`` are
    set. Samples are then batched by ``batch_size``, dropping the incomplete
    last batch only when ``is_train``.
    """
    source = CellNucleiDataset(data_dir, repeat, is_train, split)
    shard_sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
    columns = source.column_names

    def _transform(image, mask):
        # Same per-sample call the pipeline has always made.
        return preprocess_img_mask(image, mask, tuple(img_size), augment and is_train)

    pipeline = ds.GeneratorDataset(source, columns, sampler=shard_sampler)
    pipeline = pipeline.map(
        operations=_transform,
        input_columns=columns,
        output_columns=columns,
        column_order=columns,
        python_multiprocessing=python_multiprocessing,
        num_parallel_workers=num_parallel_workers,
    )
    pipeline = pipeline.batch(batch_size, drop_remainder=is_train)
    return pipeline.repeat(1)
|
|
|
|
|