add ceil nuclei dataset for unet++

pull/12878/head
zhaoting 4 years ago
parent 891abb1eb8
commit 8710c953ca

@ -24,7 +24,7 @@ from mindspore import context, Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss.loss import _Loss
from src.data_loader import create_dataset
from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet
@ -59,6 +59,7 @@ class dice_coeff(nn.Metric):
self.clear()
def clear(self):
self._dice_coeff_sum = 0
self._iou_sum = 0
self._samples_num = 0
def update(self, *inputs):
@ -77,13 +78,15 @@ class dice_coeff(nn.Metric):
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
single_dice_coeff = 2*float(inter)/float(union+1e-6)
print("single dice coeff is:", single_dice_coeff)
single_iou = single_dice_coeff / (2 - single_dice_coeff)
print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou))
self._dice_coeff_sum += single_dice_coeff
self._iou_sum += single_iou
def eval(self):
if self._samples_num == 0:
raise RuntimeError('Total samples num must not be 0.')
return self._dice_coeff_sum / float(self._samples_num)
return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num))
def test_net(data_dir,
@ -93,7 +96,8 @@ def test_net(data_dir,
if cfg['model'] == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
elif cfg['model'] == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
use_bn=cfg['use_bn'], use_ds=False)
elif cfg['model'] == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
else:
@ -102,13 +106,17 @@ def test_net(data_dir,
load_param_into_net(net, param_dict)
criterion = CrossEntropyWithLogits()
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
do_crop=cfg['crop'], img_size=cfg['img_size'])
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8)
else:
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
do_crop=cfg['crop'], img_size=cfg['img_size'])
model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})
print("============== Starting Evaluating ============")
dice_score = model.eval(valid_dataset, dataset_sink_mode=False)
print("============== Cross valid dice coeff is:", dice_score)
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
print("============== Cross valid dice coeff is:", eval_score[0])
print("============== Cross valid IOU is:", eval_score[1])
def get_args():

@ -42,7 +42,8 @@ if __name__ == "__main__":
if cfg['model'] == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
elif cfg['model'] == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
use_bn=cfg['use_bn'], use_ds=False)
elif cfg['model'] == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
else:

@ -50,6 +50,34 @@ cfg_unet_nested = {
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'use_bn': True,
'use_ds': True,
'use_deconv': True,
'resume': False,
'resume_ckpt': './',
}
cfg_unet_nested_cell = {
'model': 'unet_nested',
'dataset': 'Cell_nuclei',
'crop': None,
'img_size': [96, 96],
'lr': 3e-4,
'epochs': 200,
'distribute_epochs': 1600,
'batchsize': 16,
'cross_valid_ind': 1,
'num_classes': 2,
'num_channels': 3,
'keep_checkpoint_max': 10,
'weight_decay': 0.0005,
'loss_scale': 1024.0,
'FixedLossScaleManager': 1024.0,
'use_bn': True,
'use_ds': True,
'use_deconv': True,
'resume': False,
'resume_ckpt': './',

@ -15,6 +15,7 @@
import os
from collections import deque
import cv2
import numpy as np
from PIL import Image, ImageSequence
import mindspore.dataset as ds
@ -23,7 +24,6 @@ from mindspore.dataset.vision.utils import Inter
from mindspore.communication.management import get_rank, get_group_size
def _load_multipage_tiff(path):
"""Load tiff images containing many images in the channel dimension"""
return np.array([np.array(p) for p in ImageSequence.Iterator(Image.open(path))])
@ -164,3 +164,100 @@ def create_dataset(data_dir, repeat=400, train_batch_size=16, augment=False, cro
valid_ds = valid_ds.batch(batch_size=1, drop_remainder=True)
return train_ds, valid_ds
class CellNucleiDataset:
"""
Cell nuclei dataset preprocess class.
"""
def __init__(self, data_dir, repeat, is_train=False, split=0.8):
self.data_dir = data_dir
self.img_ids = sorted(next(os.walk(self.data_dir))[1])
self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat
np.random.shuffle(self.train_ids)
self.val_ids = self.img_ids[int(len(self.img_ids) * split):]
self.is_train = is_train
self._preprocess_dataset()
def _preprocess_dataset(self):
for img_id in self.img_ids:
path = os.path.join(self.data_dir, img_id)
if (not os.path.exists(os.path.join(path, "image.png"))) or \
(not os.path.exists(os.path.join(path, "mask.png"))):
img = cv2.imread(os.path.join(path, "images", img_id + ".png"))
if len(img.shape) == 2:
img = np.expand_dims(img, axis=-1)
img = np.concatenate([img, img, img], axis=-1)
mask = []
for mask_file in next(os.walk(os.path.join(path, "masks")))[2]:
mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE)
mask.append(mask_)
mask = np.max(mask, axis=0)
cv2.imwrite(os.path.join(path, "image.png"), img)
cv2.imwrite(os.path.join(path, "mask.png"), mask)
def _read_img_mask(self, img_id):
path = os.path.join(self.data_dir, img_id)
img = cv2.imread(os.path.join(path, "image.png"))
mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE)
return img, mask
def __getitem__(self, index):
if self.is_train:
return self._read_img_mask(self.train_ids[index])
return self._read_img_mask(self.val_ids[index])
@property
def column_names(self):
column_names = ['image', 'mask']
return column_names
def __len__(self):
if self.is_train:
return len(self.train_ids)
return len(self.val_ids)
def preprocess_img_mask(img, mask, img_size, augment=False):
"""
Preprocess for cell nuclei dataset.
Random crop and flip images and masks when augment is True.
"""
if augment:
img_size_w = int(np.random.randint(img_size[0], img_size[0] * 1.5, 1))
img_size_h = int(np.random.randint(img_size[1], img_size[1] * 1.5, 1))
img = cv2.resize(img, (img_size_w, img_size_h))
mask = cv2.resize(mask, (img_size_w, img_size_h))
dw = int(np.random.randint(0, img_size_w - img_size[0] + 1, 1))
dh = int(np.random.randint(0, img_size_h - img_size[1] + 1, 1))
img = img[dh:dh+img_size[1], dw:dw+img_size[0], :]
mask = mask[dh:dh+img_size[1], dw:dw+img_size[0]]
if np.random.random() > 0.5:
flip_code = int(np.random.randint(-1, 2, 1))
img = cv2.flip(img, flip_code)
mask = cv2.flip(mask, flip_code)
else:
img = cv2.resize(img, img_size)
mask = cv2.resize(mask, img_size)
img = (img.astype(np.float32) - 127.5) / 127.5
img = img.transpose(2, 0, 1)
mask = mask.astype(np.float32) / 255
mask = (mask > 0.5).astype(np.int)
mask = (np.arange(2) == mask[..., None]).astype(int)
mask = mask.transpose(2, 0, 1).astype(np.float32)
return img, mask
def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False,
split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8):
"""
Get generator dataset for cell nuclei dataset.
"""
cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split)
sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler)
compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train))
dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names,
output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names,
python_multiprocessing=python_multiprocessing,
num_parallel_workers=num_parallel_workers)
dataset = dataset.batch(batch_size, drop_remainder=is_train)
dataset = dataset.repeat(1)
return dataset

@ -36,3 +36,15 @@ class CrossEntropyWithLogits(_Loss):
loss = self.reduce_mean(
self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), self.reshape_fn(label, (-1, 2))))
return self.get_loss(loss)
class MultiCrossEntropyWithLogits(nn.Cell):
def __init__(self):
super(MultiCrossEntropyWithLogits, self).__init__()
self.loss = CrossEntropyWithLogits()
self.squeeze = F.Squeeze()
def construct(self, logits, label):
total_loss = 0
for i in range(len(logits)):
total_loss += self.loss(self.squeeze(logits[i:i+1]), label)
return total_loss

@ -16,6 +16,7 @@
# Model of UnetPlusPlus
import mindspore.nn as nn
import mindspore.ops as P
from .unet_parts import UnetConv2d, UnetUp
@ -63,6 +64,7 @@ class NestedUNet(nn.Cell):
self.final2 = nn.Conv2d(filters[0], n_class, 1)
self.final3 = nn.Conv2d(filters[0], n_class, 1)
self.final4 = nn.Conv2d(filters[0], n_class, 1)
self.stack = P.Stack(axis=0)
def construct(self, inputs):
x00 = self.conv00(inputs) # channel = filters[0]
@ -86,13 +88,12 @@ class NestedUNet(nn.Cell):
x04 = self.up_concat04(x13, x00, x01, x02, x03) # channel = filters[0]
final1 = self.final1(x01)
final2 = self.final1(x02)
final3 = self.final1(x03)
final4 = self.final1(x04)
final = (final1 + final2 + final3 + final4) / 4.0
final2 = self.final2(x02)
final3 = self.final3(x03)
final4 = self.final4(x04)
if self.use_ds:
final = self.stack((final1, final2, final3, final4))
return final
return final4

@ -51,9 +51,23 @@ class StepLossTimeMonitor(Callback):
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch))
self.losses.append(loss)
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
# TEST
print("step: %s, loss is %s, fps is %s" % (cur_step_in_epoch, loss, step_fps), flush=True)
def epoch_begin(self, run_context):
self.epoch_start = time.time()
self.losses = []
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_cost = time.time() - self.epoch_start
step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
step_fps = self.batch_size * 1.0 * step_in_epoch / epoch_cost
print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
def mask_to_image(mask):
return Image.fromarray((mask * 255).astype(np.uint8))

@ -21,15 +21,15 @@ import ast
import mindspore
import mindspore.nn as nn
from mindspore import Model, context
from mindspore.communication.management import init, get_group_size
from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.data_loader import create_dataset
from src.loss import CrossEntropyWithLogits
from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
from src.utils import StepLossTimeMonitor
from src.config import cfg_unet
@ -46,10 +46,12 @@ def train_net(data_dir,
run_distribute=False,
cfg=None):
rank = 0
group_size = 1
if run_distribute:
init()
group_size = get_group_size()
rank = get_rank()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=group_size,
@ -58,7 +60,8 @@ def train_net(data_dir,
if cfg['model'] == 'unet_medical':
net = UNetMedical(n_channels=cfg['num_channels'], n_classes=cfg['num_classes'])
elif cfg['model'] == 'unet_nested':
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
net = NestedUNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'], use_deconv=cfg['use_deconv'],
use_bn=cfg['use_bn'], use_ds=cfg['use_ds'])
elif cfg['model'] == 'unet_simple':
net = UNet(in_channel=cfg['num_channels'], n_class=cfg['num_classes'])
else:
@ -68,14 +71,28 @@ def train_net(data_dir,
param_dict = load_checkpoint(cfg['resume_ckpt'])
load_param_into_net(net, param_dict)
criterion = CrossEntropyWithLogits()
train_dataset, _ = create_dataset(data_dir, epochs, batch_size, True, cross_valid_ind, run_distribute, cfg["crop"],
cfg['img_size'])
if 'use_ds' in cfg and cfg['use_ds']:
criterion = MultiCrossEntropyWithLogits()
else:
criterion = CrossEntropyWithLogits()
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
repeat = 10
dataset_sink_mode = True
per_print_times = 0
train_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], repeat, batch_size,
is_train=True, augment=True, split=0.8, rank=rank,
group_size=group_size)
else:
repeat = epochs
dataset_sink_mode = False
per_print_times = 1
train_dataset, _ = create_dataset(data_dir, repeat, batch_size, True, cross_valid_ind, run_distribute,
cfg["crop"], cfg['img_size'])
train_data_size = train_dataset.get_dataset_size()
print("dataset length is:", train_data_size)
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
keep_checkpoint_max=cfg['keep_checkpoint_max'])
ckpoint_cb = ModelCheckpoint(prefix='ckpt_unet_medical_adam',
ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(cfg['model']),
directory='./ckpt_{}/'.format(device_id),
config=ckpt_config)
@ -87,13 +104,11 @@ def train_net(data_dir,
model = Model(net, loss_fn=criterion, loss_scale_manager=loss_scale_manager, optimizer=optimizer, amp_level="O3")
print("============== Starting Training ==============")
model.train(1, train_dataset, callbacks=[StepLossTimeMonitor(batch_size=batch_size), ckpoint_cb],
dataset_sink_mode=False)
callbacks = [StepLossTimeMonitor(batch_size=batch_size, per_print_times=per_print_times), ckpoint_cb]
model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
print("============== End Training ==============")
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

Loading…
Cancel
Save