!13879 change unet postprocess

From: @zhao_ting_v
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @c_34
pull/13879/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit d38200f45f

@ -16,41 +16,30 @@
import os import os
import argparse import argparse
import logging import logging
import cv2
import numpy as np import numpy as np
import mindspore
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops.operations as F import mindspore.ops.operations as F
from mindspore import context, Model from mindspore import context, Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss.loss import _Loss
from src.data_loader import create_dataset, create_cell_nuclei_dataset from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.unet_medical import UNetMedical from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet from src.unet_nested import NestedUNet, UNet
from src.config import cfg_unet from src.config import cfg_unet
from src.utils import UnetEval
from scipy.special import softmax
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
class CrossEntropyWithLogits(_Loss): class TempLoss(nn.Cell):
"""A temp loss cell."""
def __init__(self): def __init__(self):
super(CrossEntropyWithLogits, self).__init__() super(TempLoss, self).__init__()
self.transpose_fn = F.Transpose() self.identity = F.identity()
self.reshape_fn = F.Reshape()
self.softmax_cross_entropy_loss = nn.SoftmaxCrossEntropyWithLogits()
self.cast = F.Cast()
def construct(self, logits, label): def construct(self, logits, label):
# NCHW->NHWC return self.identity(logits)
logits = self.transpose_fn(logits, (0, 2, 3, 1))
logits = self.cast(logits, mindspore.float32)
label = self.transpose_fn(label, (0, 2, 3, 1))
loss = self.reduce_mean(self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)),
self.reshape_fn(label, (-1, 2))))
return self.get_loss(loss)
class dice_coeff(nn.Metric): class dice_coeff(nn.Metric):
@ -64,16 +53,35 @@ class dice_coeff(nn.Metric):
def update(self, *inputs): def update(self, *inputs):
if len(inputs) != 2: if len(inputs) != 2:
raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1]) y = self._convert_data(inputs[1])
self._samples_num += y.shape[0] self._samples_num += y.shape[0]
y_pred = y_pred.transpose(0, 2, 3, 1)
y = y.transpose(0, 2, 3, 1) y = y.transpose(0, 2, 3, 1)
y_pred = softmax(y_pred, axis=3) b, h, w, c = y.shape
if b != 1:
raise ValueError('Batch size should be 1 when in evaluation.')
y = y.reshape((h, w, c))
if cfg_unet["eval_activate"].lower() == "softmax":
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
if cfg_unet["eval_resize"]:
y_pred = []
for i in range(cfg_unet["num_classes"]):
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
y_pred = np.stack(y_pred, axis=-1)
else:
y_pred = y_softmax
elif cfg_unet["eval_activate"].lower() == "argmax":
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
y_pred = []
for i in range(cfg_unet["num_classes"]):
if cfg_unet["eval_resize"]:
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
else:
y_pred.append(np.float32(y_argmax == i))
y_pred = np.stack(y_pred, axis=-1)
else:
raise ValueError('config eval_activate should be softmax or argmax.')
y_pred = y_pred.astype(np.float32)
inter = np.dot(y_pred.flatten(), y.flatten()) inter = np.dot(y_pred.flatten(), y.flatten())
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
@ -104,14 +112,14 @@ def test_net(data_dir,
raise ValueError("Unsupported model: {}".format(cfg['model'])) raise ValueError("Unsupported model: {}".format(cfg['model']))
param_dict = load_checkpoint(ckpt_path) param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
net = UnetEval(net)
criterion = CrossEntropyWithLogits()
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8) valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
eval_resize=cfg["eval_resize"], split=0.8)
else: else:
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
do_crop=cfg['crop'], img_size=cfg['img_size']) do_crop=cfg['crop'], img_size=cfg['img_size'])
model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()}) model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
print("============== Starting Evaluating ============") print("============== Starting Evaluating ============")
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"] eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]

@ -33,7 +33,9 @@ cfg_unet_medical = {
'resume': False, 'resume': False,
'resume_ckpt': './', 'resume_ckpt': './',
'transfer_training': False, 'transfer_training': False,
'filter_weight': ['outc.weight', 'outc.bias'] 'filter_weight': ['outc.weight', 'outc.bias'],
'eval_activate': 'Softmax',
'eval_resize': False
} }
cfg_unet_nested = { cfg_unet_nested = {
@ -59,7 +61,9 @@ cfg_unet_nested = {
'resume': False, 'resume': False,
'resume_ckpt': './', 'resume_ckpt': './',
'transfer_training': False, 'transfer_training': False,
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'],
'eval_activate': 'Softmax',
'eval_resize': False
} }
cfg_unet_nested_cell = { cfg_unet_nested_cell = {
@ -86,7 +90,9 @@ cfg_unet_nested_cell = {
'resume': False, 'resume': False,
'resume_ckpt': './', 'resume_ckpt': './',
'transfer_training': False, 'transfer_training': False,
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'],
'eval_activate': 'Softmax',
'eval_resize': False
} }
cfg_unet_simple = { cfg_unet_simple = {
@ -109,7 +115,12 @@ cfg_unet_simple = {
'resume': False, 'resume': False,
'resume_ckpt': './', 'resume_ckpt': './',
'transfer_training': False, 'transfer_training': False,
'filter_weight': ["final.weight"] 'filter_weight': ["final.weight"],
'eval_activate': 'Softmax',
'eval_resize': False
} }
cfg_unet = cfg_unet_medical cfg_unet = cfg_unet_medical
if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']:
print("ISBI dataset not support resize to original image size when in evaluation.")
cfg_unet['eval_resize'] = False

@ -216,7 +216,7 @@ class CellNucleiDataset:
return len(self.train_ids) return len(self.train_ids)
return len(self.val_ids) return len(self.val_ids)
def preprocess_img_mask(img, mask, img_size, augment=False): def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False):
""" """
Preprocess for cell nuclei dataset. Preprocess for cell nuclei dataset.
Random crop and flip images and masks when augment is True. Random crop and flip images and masks when augment is True.
@ -236,7 +236,8 @@ def preprocess_img_mask(img, mask, img_size, augment=False):
mask = cv2.flip(mask, flip_code) mask = cv2.flip(mask, flip_code)
else: else:
img = cv2.resize(img, img_size) img = cv2.resize(img, img_size)
mask = cv2.resize(mask, img_size) if not eval_resize:
mask = cv2.resize(mask, img_size)
img = (img.astype(np.float32) - 127.5) / 127.5 img = (img.astype(np.float32) - 127.5) / 127.5
img = img.transpose(2, 0, 1) img = img.transpose(2, 0, 1)
mask = mask.astype(np.float32) / 255 mask = mask.astype(np.float32) / 255
@ -245,7 +246,7 @@ def preprocess_img_mask(img, mask, img_size, augment=False):
mask = mask.transpose(2, 0, 1).astype(np.float32) mask = mask.transpose(2, 0, 1).astype(np.float32)
return img, mask return img, mask
def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, eval_resize=False,
split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8): split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8):
""" """
Get generator dataset for cell nuclei dataset. Get generator dataset for cell nuclei dataset.
@ -253,7 +254,8 @@ def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=
cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split) cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split)
sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train) sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train)
dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler) dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler)
compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train)) compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train,
eval_resize))
dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names, dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names,
output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names, output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names,
python_multiprocessing=python_multiprocessing, python_multiprocessing=python_multiprocessing,

@ -16,9 +16,29 @@
import time import time
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from mindspore import nn
from mindspore.ops import operations as ops
from mindspore.train.callback import Callback from mindspore.train.callback import Callback
from mindspore.common.tensor import Tensor from mindspore.common.tensor import Tensor
class UnetEval(nn.Cell):
"""
Add Unet evaluation activation.
"""
def __init__(self, net):
super(UnetEval, self).__init__()
self.net = net
self.transpose = ops.Transpose()
self.softmax = ops.Softmax(axis=-1)
self.argmax = ops.Argmax(axis=-1)
def construct(self, x):
out = self.net(x)
out = self.transpose(out, (0, 2, 3, 1))
softmax_out = self.softmax(out)
argmax_out = self.argmax(out)
return (softmax_out, argmax_out)
class StepLossTimeMonitor(Callback): class StepLossTimeMonitor(Callback):
def __init__(self, batch_size, per_print_times=1): def __init__(self, batch_size, per_print_times=1):

Loading…
Cancel
Save