|
|
@ -16,41 +16,30 @@
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import argparse
|
|
|
|
import argparse
|
|
|
|
import logging
|
|
|
|
import logging
|
|
|
|
|
|
|
|
import cv2
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import mindspore
|
|
|
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
import mindspore.nn as nn
|
|
|
|
import mindspore.ops.operations as F
|
|
|
|
import mindspore.ops.operations as F
|
|
|
|
from mindspore import context, Model
|
|
|
|
from mindspore import context, Model
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
from mindspore.nn.loss.loss import _Loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from src.data_loader import create_dataset, create_cell_nuclei_dataset
|
|
|
|
from src.data_loader import create_dataset, create_cell_nuclei_dataset
|
|
|
|
from src.unet_medical import UNetMedical
|
|
|
|
from src.unet_medical import UNetMedical
|
|
|
|
from src.unet_nested import NestedUNet, UNet
|
|
|
|
from src.unet_nested import NestedUNet, UNet
|
|
|
|
from src.config import cfg_unet
|
|
|
|
from src.config import cfg_unet
|
|
|
|
|
|
|
|
from src.utils import UnetEval
|
|
|
|
from scipy.special import softmax
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Select the Ascend device for this process. The launcher normally exports
# DEVICE_ID; default to '0' for single-device runs where it is absent
# (int(os.getenv('DEVICE_ID')) would raise TypeError on None).
device_id = int(os.getenv('DEVICE_ID', '0'))

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    save_graphs=False, device_id=device_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TempLoss(nn.Cell):
    """A temp loss cell.

    Placeholder loss passed to ``Model`` so evaluation can run: the network
    output (already post-processed by ``UnetEval``) is forwarded unchanged
    and the label is ignored — the real scoring happens in the metric.
    """

    def __init__(self):
        super(TempLoss, self).__init__()
        # Identity op: emits its input tensor unmodified.
        self.identity = F.identity()

    def construct(self, logits, label):
        # 'label' is deliberately unused; dice_coeff consumes the raw
        # network outputs directly.
        return self.identity(logits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class dice_coeff(nn.Metric):
|
|
|
|
class dice_coeff(nn.Metric):
|
|
|
@ -64,16 +53,35 @@ class dice_coeff(nn.Metric):
|
|
|
|
|
|
|
|
|
|
|
|
def update(self, *inputs):
|
|
|
|
def update(self, *inputs):
|
|
|
|
if len(inputs) != 2:
|
|
|
|
if len(inputs) != 2:
|
|
|
|
raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
|
|
|
|
raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs)))
|
|
|
|
|
|
|
|
|
|
|
|
y_pred = self._convert_data(inputs[0])
|
|
|
|
|
|
|
|
y = self._convert_data(inputs[1])
|
|
|
|
y = self._convert_data(inputs[1])
|
|
|
|
|
|
|
|
|
|
|
|
self._samples_num += y.shape[0]
|
|
|
|
self._samples_num += y.shape[0]
|
|
|
|
y_pred = y_pred.transpose(0, 2, 3, 1)
|
|
|
|
|
|
|
|
y = y.transpose(0, 2, 3, 1)
|
|
|
|
y = y.transpose(0, 2, 3, 1)
|
|
|
|
y_pred = softmax(y_pred, axis=3)
|
|
|
|
b, h, w, c = y.shape
|
|
|
|
|
|
|
|
if b != 1:
|
|
|
|
|
|
|
|
raise ValueError('Batch size should be 1 when in evaluation.')
|
|
|
|
|
|
|
|
y = y.reshape((h, w, c))
|
|
|
|
|
|
|
|
if cfg_unet["eval_activate"].lower() == "softmax":
|
|
|
|
|
|
|
|
y_softmax = np.squeeze(self._convert_data(inputs[0][0]), axis=0)
|
|
|
|
|
|
|
|
if cfg_unet["eval_resize"]:
|
|
|
|
|
|
|
|
y_pred = []
|
|
|
|
|
|
|
|
for i in range(cfg_unet["num_classes"]):
|
|
|
|
|
|
|
|
y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, i] * 255), (w, h)) / 255)
|
|
|
|
|
|
|
|
y_pred = np.stack(y_pred, axis=-1)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
y_pred = y_softmax
|
|
|
|
|
|
|
|
elif cfg_unet["eval_activate"].lower() == "argmax":
|
|
|
|
|
|
|
|
y_argmax = np.squeeze(self._convert_data(inputs[0][1]), axis=0)
|
|
|
|
|
|
|
|
y_pred = []
|
|
|
|
|
|
|
|
for i in range(cfg_unet["num_classes"]):
|
|
|
|
|
|
|
|
if cfg_unet["eval_resize"]:
|
|
|
|
|
|
|
|
y_pred.append(cv2.resize(np.uint8(y_argmax == i), (w, h), interpolation=cv2.INTER_NEAREST))
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
y_pred.append(np.float32(y_argmax == i))
|
|
|
|
|
|
|
|
y_pred = np.stack(y_pred, axis=-1)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise ValueError('config eval_activate should be softmax or argmax.')
|
|
|
|
|
|
|
|
y_pred = y_pred.astype(np.float32)
|
|
|
|
inter = np.dot(y_pred.flatten(), y.flatten())
|
|
|
|
inter = np.dot(y_pred.flatten(), y.flatten())
|
|
|
|
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
|
|
|
|
union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten())
|
|
|
|
|
|
|
|
|
|
|
@ -104,14 +112,14 @@ def test_net(data_dir,
|
|
|
|
raise ValueError("Unsupported model: {}".format(cfg['model']))
|
|
|
|
raise ValueError("Unsupported model: {}".format(cfg['model']))
|
|
|
|
param_dict = load_checkpoint(ckpt_path)
|
|
|
|
param_dict = load_checkpoint(ckpt_path)
|
|
|
|
load_param_into_net(net, param_dict)
|
|
|
|
load_param_into_net(net, param_dict)
|
|
|
|
|
|
|
|
net = UnetEval(net)
|
|
|
|
criterion = CrossEntropyWithLogits()
|
|
|
|
|
|
|
|
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
|
|
|
|
if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei":
|
|
|
|
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, split=0.8)
|
|
|
|
valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False,
|
|
|
|
|
|
|
|
eval_resize=cfg["eval_resize"], split=0.8)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
|
|
|
|
_, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False,
|
|
|
|
do_crop=cfg['crop'], img_size=cfg['img_size'])
|
|
|
|
do_crop=cfg['crop'], img_size=cfg['img_size'])
|
|
|
|
model = Model(net, loss_fn=criterion, metrics={"dice_coeff": dice_coeff()})
|
|
|
|
model = Model(net, loss_fn=TempLoss(), metrics={"dice_coeff": dice_coeff()})
|
|
|
|
|
|
|
|
|
|
|
|
print("============== Starting Evaluating ============")
|
|
|
|
print("============== Starting Evaluating ============")
|
|
|
|
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
|
|
|
|
eval_score = model.eval(valid_dataset, dataset_sink_mode=False)["dice_coeff"]
|
|
|
|