From 3ccab4efa416266a1f8866d0800906cf15b01f00 Mon Sep 17 00:00:00 2001 From: zhaoting Date: Tue, 23 Mar 2021 11:10:55 +0800 Subject: [PATCH] change unet postprocess --- model_zoo/official/cv/unet/eval.py | 66 +++++++++++-------- model_zoo/official/cv/unet/src/config.py | 19 ++++-- model_zoo/official/cv/unet/src/data_loader.py | 10 +-- model_zoo/official/cv/unet/src/utils.py | 20 ++++++ 4 files changed, 78 insertions(+), 37 deletions(-) diff --git a/model_zoo/official/cv/unet/eval.py b/model_zoo/official/cv/unet/eval.py index 9f23ff9273..a7fad58e28 100644 --- a/model_zoo/official/cv/unet/eval.py +++ b/model_zoo/official/cv/unet/eval.py @@ -16,41 +16,30 @@ import os import argparse import logging +import cv2 import numpy as np -import mindspore import mindspore.nn as nn import mindspore.ops.operations as F from mindspore import context, Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.nn.loss.loss import _Loss from src.data_loader import create_dataset, create_cell_nuclei_dataset from src.unet_medical import UNetMedical from src.unet_nested import NestedUNet, UNet from src.config import cfg_unet - -from scipy.special import softmax +from src.utils import UnetEval device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) -class CrossEntropyWithLogits(_Loss): +class TempLoss(nn.Cell): + """A temp loss cell.""" def __init__(self): - super(CrossEntropyWithLogits, self).__init__() - self.transpose_fn = F.Transpose() - self.reshape_fn = F.Reshape() - self.softmax_cross_entropy_loss = nn.SoftmaxCrossEntropyWithLogits() - self.cast = F.Cast() + super(TempLoss, self).__init__() + self.identity = F.identity() def construct(self, logits, label): - # NCHW->NHWC - logits = self.transpose_fn(logits, (0, 2, 3, 1)) - logits = self.cast(logits, mindspore.float32) - label = self.transpose_fn(label, (0, 2, 3, 1)) - - loss = self.reduce_mean(self.softmax_cross_entropy_loss(self.reshape_fn(logits, (-1, 2)), - self.reshape_fn(label, (-1, 2)))) - return self.get_loss(loss) + return self.identity(logits) class dice_coeff(nn.Metric): @@ -64,16 +53,35 @@ class dice_coeff(nn.Metric): def update(self, *inputs): if len(inputs) != 2: - raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) - - y_pred = self._convert_data(inputs[0]) + raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs))) y = self._convert_data(inputs[1]) - self._samples_num += y.shape[0] - y_pred = y_pred.transpose(0, 2, 3, 1) y = y.transpose(0, 2, 3, 1) - y_pred = softmax(y_pred, axis=3) - + b, h, w, c = y.shape + if b != 1: + raise ValueError('Batch size should be 1 when in evaluation.') + y = y.reshape((h, w, c)) + if cfg_unet["eval_activate"].lower() == "softmax": + y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0) + if cfg_unet["eval_resize"]: + y_pred = [] + for i in range(cfg_unet["num_classes"]): + y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255) + y_pred = np.stack(y_pred, axis=-1) + else: + y_pred = y_softmax + elif cfg_unet["eval_activate"].lower() == "argmax": + y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0) + y_pred = [] + for i in range(cfg_unet["num_classes"]): + if cfg_unet["eval_resize"]: + y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST)) + else: + y_pred.append(np.float32(y_argmax == i)) + y_pred = np.stack(y_pred, axis=-1) + else: + raise ValueError('config eval_activate should be softmax or argmax.') + y_pred = y_pred.astype(np.float32) inter = np.dot(y_pred.flatten(), y.flatten()) union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) @@ -104,14 +112,14 @@ def test_net(data_dir, raise ValueError("Unsupported model: {}".format(cfg['model'])) param_dict = load_checkpoint(ckpt_path) load_param_into_net(net, param_dict) - - criterion = CrossEntropyWithLogits() + net = UnetEval(net) if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": - valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8) + valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, + eval_resize=cfg["eval_resize"], split=0.8) else: _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], img_size=cfg['img_size']) - model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()}) + model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()}) print("============== Starting Evaluating ============") eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"] diff --git a/model_zoo/official/cv/unet/src/config.py b/model_zoo/official/cv/unet/src/config.py index 33bfe3bce1..881de673f1 100644 --- a/model_zoo/official/cv/unet/src/config.py +++ b/model_zoo/official/cv/unet/src/config.py @@ -33,7 +33,9 @@ cfg_unet_medical = { 'resume': False, 'resume_ckpt': './', 'transfer_training': False, - 'filter_weight': ['outc.weight', 'outc.bias'] + 'filter_weight': ['outc.weight', 'outc.bias'], + 'eval_activate': 'Softmax', + 'eval_resize': False } cfg_unet_nested = { @@ -59,7 +61,9 @@ cfg_unet_nested = { 'resume': False, 'resume_ckpt': './', 'transfer_training': False, - 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] + 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'], + 'eval_activate': 'Softmax', + 'eval_resize': False } cfg_unet_nested_cell = { @@ -86,7 +90,9 @@ cfg_unet_nested_cell = { 'resume': False, 'resume_ckpt': './', 'transfer_training': False, - 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'] + 'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight'], + 'eval_activate': 'Softmax', + 'eval_resize': False } cfg_unet_simple = { @@ -109,7 +115,12 @@ cfg_unet_simple = { 'resume': False, 'resume_ckpt': './', 'transfer_training': False, - 'filter_weight': ["final.weight"] + 'filter_weight': ["final.weight"], + 'eval_activate': 'Softmax', + 'eval_resize': False } cfg_unet = cfg_unet_medical +if not ('dataset' in cfg_unet and cfg_unet['dataset'] == 'Cell_nuclei') and cfg_unet['eval_resize']: + print("ISBI dataset not support resize to original image size when in evaluation.") + cfg_unet['eval_resize'] = False diff --git a/model_zoo/official/cv/unet/src/data_loader.py b/model_zoo/official/cv/unet/src/data_loader.py index 526ca14ed5..3dee1493b7 100644 --- a/model_zoo/official/cv/unet/src/data_loader.py +++ b/model_zoo/official/cv/unet/src/data_loader.py @@ -216,7 +216,7 @@ class CellNucleiDataset: return len(self.train_ids) return len(self.val_ids) -def preprocess_img_mask(img, mask, img_size, augment=False): +def preprocess_img_mask(img, mask, img_size, augment=False, eval_resize=False): """ Preprocess for cell nuclei dataset. Random crop and flip images and masks when augment is True. @@ -236,7 +236,8 @@ def preprocess_img_mask(img, mask, img_size, augment=False): mask = cv2.flip(mask, flip_code) else: img = cv2.resize(img, img_size) - mask = cv2.resize(mask, img_size) + if not eval_resize: + mask = cv2.resize(mask, img_size) img = (img.astype(np.float32) - 127.5) / 127.5 img = img.transpose(2, 0, 1) mask = mask.astype(np.float32) / 255 @@ -245,7 +246,7 @@ def preprocess_img_mask(img, mask, img_size, augment=False): mask = mask.transpose(2, 0, 1).astype(np.float32) return img, mask -def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, +def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train=False, augment=False, eval_resize=False, split=0.8, rank=0, group_size=1, python_multiprocessing=True, num_parallel_workers=8): """ Get generator dataset for cell nuclei dataset. @@ -253,7 +254,8 @@ def create_cell_nuclei_dataset(data_dir, img_size, repeat, batch_size, is_train= cell_dataset = CellNucleiDataset(data_dir, repeat, is_train, split) sampler = ds.DistributedSampler(group_size, rank, shuffle=is_train) dataset = ds.GeneratorDataset(cell_dataset, cell_dataset.column_names, sampler=sampler) - compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train)) + compose_map_func = (lambda image, mask: preprocess_img_mask(image, mask, tuple(img_size), augment and is_train, + eval_resize)) dataset = dataset.map(operations=compose_map_func, input_columns=cell_dataset.column_names, output_columns=cell_dataset.column_names, column_order=cell_dataset.column_names, python_multiprocessing=python_multiprocessing, diff --git a/model_zoo/official/cv/unet/src/utils.py b/model_zoo/official/cv/unet/src/utils.py index b9f2846296..5285ee5c83 100644 --- a/model_zoo/official/cv/unet/src/utils.py +++ b/model_zoo/official/cv/unet/src/utils.py @@ -16,9 +16,29 @@ import time import numpy as np from PIL import Image +from mindspore import nn +from mindspore.ops import operations as ops from mindspore.train.callback import Callback from mindspore.common.tensor import Tensor +class UnetEval(nn.Cell): + """ + Add Unet evaluation activation. + """ + def __init__(self, net): + super(UnetEval, self).__init__() + self.net = net + self.transpose = ops.Transpose() + self.softmax = ops.Softmax(axis=-1) + self.argmax = ops.Argmax(axis=-1) + + def construct(self, x): + out = self.net(x) + out = self.transpose(out, (0, 2, 3, 1)) + softmax_out = self.softmax(out) + argmax_out = self.argmax(out) + return (softmax_out, argmax_out) + class StepLossTimeMonitor(Callback): def __init__(self, batch_size, per_print_times=1):