!13879 change unet postprocess

From: @zhao_ting_v
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @c_34
pull/13879/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit d38200f45f

@@ -16,41 +16,30 @@
 import os
 import argparse
 import logging
+import cv2
 import numpy as np
-import mindspore
 import mindspore.nn as nn
 import mindspore.ops.operations as F
 from mindspore import context, Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.loss.loss import _Loss
 from src.data_loader import create_dataset, create_cell_nuclei_dataset
 from src.unet_medical import UNetMedical
 from src.unet_nested import NestedUNet, UNet
 from src.config import cfg_unet
-from scipy.special import softmax
+from src.utils import UnetEval
 
 device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
 
-class CrossEntropyWithLogits(_Loss):
+class TempLoss(nn.Cell):
+    """A temp loss cell."""
     def __init__(self):
-        super(CrossEntropyWithLogits, self).__init__()
-        self.transpose_fn = F.Transpose()
-        self.reshape_fn = F.Reshape()
-        self.softmax_cross_entropy_loss = nn.SoftmaxCrossEntropyWithLogits()
-        self.cast = F.Cast()
+        super(TempLoss, self).__init__()
+        self.identity = F.identity()
 
     def construct(self, logits, label):
-        # NCHW->NHWC
-        logits = self.transpose_fn(logits, (0, 2, 3, 1))
-        logits = self.cast(logits, mindspore.float32)
-        label = self.transpose_fn(label, (0, 2, 3, 1))
-        loss = self.reduce_mean(self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)),
-                                                                self.reshape_fn(label, (-1, 2))))
-        return self.get_loss(loss)
+        return self.identity(logits)
 
 class dice_coeff(nn.Metric):
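Note: Model() still requires a loss_fn during evaluation, so TempLoss exists only to forward the network output to the metric untouched. The NumPy-only sketch below (not part of the commit; shapes are illustrative) shows what now flows into dice_coeff: UnetEval emits a (softmax, argmax) pair in NHWC layout.

# Illustrative NumPy sketch of the new eval outputs (assumed: batch 1,
# 2 classes, 4x4 image); TempLoss passes this tuple through unchanged.
import numpy as np

logits = np.random.randn(1, 2, 4, 4).astype(np.float32)   # NCHW from UNet
nhwc = logits.transpose(0, 2, 3, 1)                        # NCHW -> NHWC
exp = np.exp(nhwc - nhwc.max(axis=-1, keepdims=True))      # stable softmax
softmax_out = exp / exp.sum(axis=-1, keepdims=True)        # (1, 4, 4, 2)
argmax_out = nhwc.argmax(axis=-1)                          # (1, 4, 4)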
@@ -64,16 +53,35 @@ class dice_coeff(nn.Metric):
     def update(self, *inputs):
         if len(inputs) != 2:
-            raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
-        y_pred = self._convert_data(inputs[0])
+            raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
         y = self._convert_data(inputs[1])
         self._samples_num += y.shape[0]
-        y_pred = y_pred.transpose(0, 2, 3, 1)
         y = y.transpose(0, 2, 3, 1)
-        y_pred = softmax(y_pred, axis=3)
+        b, h, w, c = y.shape
+        if b != 1:
+            raise ValueError('Batch size should be 1 when in evaluation.')
+        y = y.reshape((h, w, c))
+        if cfg_unet["eval_activate"].lower() == "softmax":
+            y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
+            if cfg_unet["eval_resize"]:
+                y_pred = []
+                for i in range(cfg_unet["num_classes"]):
+                    y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
+                y_pred = np.stack(y_pred, axis=-1)
+            else:
+                y_pred = y_softmax
+        elif cfg_unet["eval_activate"].lower() == "argmax":
+            y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
+            y_pred = []
+            for i in range(cfg_unet["num_classes"]):
+                if cfg_unet["eval_resize"]:
+                    y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
+                else:
+                    y_pred.append(np.float32(y_argmax == i))
+            y_pred = np.stack(y_pred, axis=-1)
+        else:
+            raise ValueError('config eval_activate should be softmax or argmax.')
+        y_pred = y_pred.astype(np.float32)
 
         inter = np.dot(y_pred.flatten(), y.flatten())
         union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
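For reference, the dice computation following these branches reduces to inner products over the flattened prediction and label. A toy NumPy check (the 1e-6 epsilon guard is an assumption about the surrounding code, which this hunk does not show):

# dice = 2 * <y_pred, y> / (<y_pred, y_pred> + <y, y>); values are illustrative.
import numpy as np

y_pred = np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32).flatten()
y = np.array([[1.0, 0.0], [1.0, 0.0]], dtype=np.float32).flatten()
inter = np.dot(y_pred, y)                            # 1.0
union = np.dot(y_pred, y_pred) + np.dot(y, y)        # 4.0
dice = 2 * inter / (union + 1e-6)                    # epsilon is an assumption
print(dice)                                          # ~0.5 for this toy pair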
@@ -104,14 +112,14 @@ def test_net(data_dir,
         raise ValueError("Unsupported model: {}".format(cfg['model']))
     param_dict = load_checkpoint(ckpt_path)
     load_param_into_net(net, param_dict)
-    criterion = CrossEntropyWithLogits()
+    net = UnetEval(net)
     if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
-        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8)
+        valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
+                                                   eval_resize=cfg["eval_resize"], split=0.8)
     else:
         _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
                                           do_crop=cfg['crop'], img_size=cfg['img_size'])
-    model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})
+    model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
 
     print("============== Starting Evaluating ============")
     eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
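Putting this hunk together, evaluation now looks roughly like the sketch below. It assumes the imports and valid_dataset from eval.py; the checkpoint path and the UNetMedical constructor argument names are hypothetical.

# Hedged end-to-end eval sketch mirroring the calls in this hunk.
net = UNetMedical(n_channels=1, n_classes=2)                 # arg names assumed
load_param_into_net(net, load_checkpoint("./unet_medical.ckpt"))  # path assumed
net = UnetEval(net)                                          # softmax/argmax heads
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
print(model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"])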

@@ -33,7 +33,9 @@ cfg_unet_medical = {
     'resume': False,
     'resume_ckpt': './',
     'transfer_training': False,
-    'filter_weight': ['outc.weight', 'outc.bias']
+    'filter_weight': ['outc.weight', 'outc.bias'],
+    'eval_activate': 'Softmax',
+    'eval_resize': False
 }
 
 cfg_unet_nested = {
@@ -59,7 +61,9 @@ cfg_unet_nested = {
     'resume': False,
     'resume_ckpt': './',
     'transfer_training': False,
-    'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
+    'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'],
+    'eval_activate': 'Softmax',
+    'eval_resize': False
 }
 
 cfg_unet_nested_cell = {
@@ -86,7 +90,9 @@ cfg_unet_nested_cell = {
     'resume': False,
     'resume_ckpt': './',
     'transfer_training': False,
-    'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
+    'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'],
+    'eval_activate': 'Softmax',
+    'eval_resize': False
 }
 
 cfg_unet_simple = {
@@ -109,7 +115,12 @@ cfg_unet_simple = {
     'resume': False,
     'resume_ckpt': './',
     'transfer_training': False,
-    'filter_weight': ["final.weight"]
+    'filter_weight': ["final.weight"],
+    'eval_activate': 'Softmax',
+    'eval_resize': False
 }
 
 cfg_unet = cfg_unet_medical
+if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']:
+    print("ISBI dataset not support resize to original image size when in evaluation.")
+    cfg_unet['eval_resize'] = False
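The two new keys are consumed in dice_coeff.update() above: 'eval_activate' picks which element of the UnetEval output tuple becomes y_pred, and 'eval_resize' controls whether predictions are resized back to the label's original size. A minimal sketch of that dispatch (names mirror eval.py; not part of the commit):

# How the new config keys gate the postprocess path.
mode = cfg_unet["eval_activate"].lower()       # 'softmax' or 'argmax'
if mode not in ("softmax", "argmax"):
    raise ValueError('config eval_activate should be softmax or argmax.')
use_original_size = cfg_unet["eval_resize"]    # resize preds back to label size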

@@ -216,7 +216,7 @@ class CellNucleiDataset:
             return len(self.train_ids)
         return len(self.val_ids)
 
-def preprocess_img_mask(img, mask, img_size, augment=False):
+def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False):
     """
     Preprocess for cell nuclei dataset.
     Random crop and flip images and masks when augment is True.
@@ -236,7 +236,8 @@ def preprocess_img_mask(img, mask, img_size, augment=False):
             mask = cv2.flip(mask, flip_code)
     else:
         img = cv2.resize(img, img_size)
-        mask = cv2.resize(mask, img_size)
+        if not eval_resize:
+            mask = cv2.resize(mask, img_size)
     img = (img.astype(np.float32) - 127.5) / 127.5
     img = img.transpose(2, 0, 1)
     mask = mask.astype(np.float32) / 255
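With eval_resize set, the image is still resized to the network input size while the mask keeps its original resolution, so dice_coeff can resize predictions back to it. A toy check with synthetic arrays (sizes are illustrative):

# Toy shape check for the eval_resize branch (synthetic data, not dataset code).
import cv2
import numpy as np

img = np.zeros((388, 388, 3), dtype=np.uint8)
mask = np.zeros((388, 388), dtype=np.uint8)
img_size, eval_resize = (96, 96), True

img = cv2.resize(img, img_size)
if not eval_resize:                    # skipped: mask keeps its original size
    mask = cv2.resize(mask, img_size)
print(img.shape, mask.shape)           # (96, 96, 3) (388, 388)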
@@ -245,7 +246,7 @@ def preprocess_img_mask(img, mask, img_size, augment=False):
         mask = mask.transpose(2, 0, 1).astype(np.float32)
     return img, mask
 
-def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False,
+def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, eval_resize=False,
                                split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8):
     """
     Get generator dataset for cell nuclei dataset.
@@ -253,7 +254,8 @@ def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=
     cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split)
     sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
     dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler)
-    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train))
+    compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train,
+                                                                eval_resize))
     dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names,
                           output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names,
                           python_multiprocessing=python_multiprocessing,
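Downstream, the new parameter just threads through to the map function. A hedged usage sketch (the data directory and image size are hypothetical):

# Hypothetical call mirroring eval.py; batch_size stays 1 because dice_coeff
# rejects larger batches during evaluation (see the ValueError above).
valid_dataset = create_cell_nuclei_dataset("./data/cell_nuclei", img_size=(96, 96),
                                           repeat=1, batch_size=1, is_train=False,
                                           eval_resize=True, split=0.8)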

@@ -16,9 +16,29 @@
 import time
 import numpy as np
 from PIL import Image
+from mindspore import nn
+from mindspore.ops import operations as ops
 from mindspore.train.callback import Callback
 from mindspore.common.tensor import Tensor
 
+class UnetEval(nn.Cell):
+    """
+    Add Unet evaluation activation.
+    """
+    def __init__(self, net):
+        super(UnetEval, self).__init__()
+        self.net = net
+        self.transpose = ops.Transpose()
+        self.softmax = ops.Softmax(axis=-1)
+        self.argmax = ops.Argmax(axis=-1)
+
+    def construct(self, x):
+        out = self.net(x)
+        out = self.transpose(out, (0, 2, 3, 1))
+        softmax_out = self.softmax(out)
+        argmax_out = self.argmax(out)
+        return (softmax_out, argmax_out)
+
 class StepLossTimeMonitor(Callback):
     def __init__(self, batch_size, per_print_times=1):
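A standalone shape check for the new wrapper, assuming this file lands as src/utils.py so UnetEval is importable; DummyNet and the CPU target are stand-ins for illustration, not part of the commit.

# Shape check: UnetEval turns an NCHW logit map into an NHWC (softmax, argmax) pair.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from src.utils import UnetEval   # the class added above

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

class DummyNet(nn.Cell):
    """Stands in for a UNet: emits a 2-class NCHW logit map."""
    def __init__(self):
        super(DummyNet, self).__init__()
        self.conv = nn.Conv2d(1, 2, 3, pad_mode="same")
    def construct(self, x):
        return self.conv(x)

softmax_out, argmax_out = UnetEval(DummyNet())(Tensor(np.ones((1, 1, 8, 8), np.float32)))
print(softmax_out.shape, argmax_out.shape)   # (1, 8, 8, 2) (1, 8, 8)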
