From 620e06e55d0282822b1fe9d00c5efd253f6b1b0c Mon Sep 17 00:00:00 2001 From: lihongkang <[lihongkang1@huawei.com]> Date: Sat, 27 Mar 2021 20:20:05 +0800 Subject: [PATCH] add unet++ 310 mindir infer and update unet 310 infer code --- .../cv/unet/ascend310_infer/src/main.cc | 26 +++++- model_zoo/official/cv/unet/export.py | 2 + model_zoo/official/cv/unet/postprocess.py | 83 ++++++++++++++----- model_zoo/official/cv/unet/preprocess.py | 61 +++++++++++++- .../official/cv/unet/scripts/run_infer_310.sh | 13 +-- 5 files changed, 155 insertions(+), 30 deletions(-) diff --git a/model_zoo/official/cv/unet/ascend310_infer/src/main.cc b/model_zoo/official/cv/unet/ascend310_infer/src/main.cc index 66324e2495..34a043fd15 100644 --- a/model_zoo/official/cv/unet/ascend310_infer/src/main.cc +++ b/model_zoo/official/cv/unet/ascend310_infer/src/main.cc @@ -42,11 +42,17 @@ using mindspore::MSTensor; using mindspore::ModelType; using mindspore::GraphCell; using mindspore::kSuccess; +using mindspore::dataset::vision::Decode; +using mindspore::dataset::vision::SwapRedBlue; +using mindspore::dataset::vision::Normalize; +using mindspore::dataset::vision::Resize; +using mindspore::dataset::vision::HWC2CHW; DEFINE_string(mindir_path, "", "mindir path"); DEFINE_string(dataset_path, ".", "dataset path"); DEFINE_int32(device_id, 0, "device id"); +DEFINE_string(need_preprocess, "n", "need preprocess or not"); int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -78,6 +84,14 @@ int main(int argc, char **argv) { std::map costTime_map; size_t size = all_files.size(); + + auto decode(new Decode()); + auto swapredblue(new SwapRedBlue()); + auto resize(new Resize({96, 96})); + auto normalize(new Normalize({127.5, 127.5, 127.5}, {127.5, 127.5, 127.5})); + auto hwc2chw(new HWC2CHW()); + Execute preprocess({decode, swapredblue, resize, normalize, hwc2chw}); + for (size_t i = 0; i < size; ++i) { struct timeval start = {0}; struct timeval end = {0}; @@ -86,7 +100,17 @@ int main(int argc, char **argv) { std::vector inputs; std::vector outputs; std::cout << "Start predict input files:" << all_files[i] << std::endl; - auto img = ReadFileToTensor(all_files[i]); + + auto img = MSTensor(); + if (FLAGS_need_preprocess == "y") { + ret = preprocess(ReadFileToTensor(all_files[i]), &img); + if (ret != kSuccess) { + std::cout << "preprocess " << all_files[i] << " failed." << std::endl; + return 1; + } + } else { + img = ReadFileToTensor(all_files[i]); + } inputs.emplace_back(model_inputs[0].Name(), model_inputs[0].DataType(), model_inputs[0].Shape(), img.Data().get(), img.DataSize()); diff --git a/model_zoo/official/cv/unet/export.py b/model_zoo/official/cv/unet/export.py index 923e3b1678..b4af88f316 100644 --- a/model_zoo/official/cv/unet/export.py +++ b/model_zoo/official/cv/unet/export.py @@ -21,6 +21,7 @@ from mindspore import Tensor, export, load_checkpoint, load_param_into_net, cont from src.unet_medical.unet_model import UNetMedical from src.unet_nested import NestedUNet, UNet from src.config import cfg_unet as cfg +from src.utils import UnetEval parser = argparse.ArgumentParser(description='unet export') parser.add_argument("--device_id", type=int, default=0, help="Device id") @@ -52,5 +53,6 @@ if __name__ == "__main__": param_dict = load_checkpoint(args.ckpt_file) # load the parameter into net load_param_into_net(net, param_dict) + net = UnetEval(net) input_data = Tensor(np.ones([args.batch_size, cfg["num_channels"], args.height, args.width]).astype(np.float32)) export(net, input_data, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/cv/unet/postprocess.py b/model_zoo/official/cv/unet/postprocess.py index bc179bd39d..530d3ae19e 100644 --- a/model_zoo/official/cv/unet/postprocess.py +++ b/model_zoo/official/cv/unet/postprocess.py @@ -15,53 +15,77 @@ """unet 310 infer.""" import os import argparse +import cv2 import numpy as np -from src.data_loader import create_dataset +from src.data_loader import create_dataset, create_cell_nuclei_dataset from src.config import cfg_unet -from scipy.special import softmax class dice_coeff(): def __init__(self): self.clear() - def clear(self): self._dice_coeff_sum = 0 + self._iou_sum = 0 self._samples_num = 0 def update(self, *inputs): if len(inputs) != 2: - raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) - - y_pred = inputs[0] + raise ValueError('Need 2 inputs ((y_softmax, y_argmax), y), but got {}'.format(len(inputs))) y = np.array(inputs[1]) - self._samples_num += y.shape[0] - y_pred = y_pred.transpose(0, 2, 3, 1) y = y.transpose(0, 2, 3, 1) - y_pred = softmax(y_pred, axis=3) - + b, h, w, c = y.shape + if b != 1: + raise ValueError('Batch size should be 1 when in evaluation.') + y = y.reshape((h, w, c)) + if cfg_unet["eval_activate"].lower() == "softmax": + y_softmax = np.squeeze(inputs[0][0], axis=0) + if cfg_unet["eval_resize"]: + y_pred = [] + for m in range(cfg_unet["num_classes"]): + y_pred.append(cv2.resize(np.uint8(y_softmax[:, :, m] * 255), (w, h)) / 255) + y_pred = np.stack(y_pred, axis=-1) + else: + y_pred = y_softmax + elif cfg_unet["eval_activate"].lower() == "argmax": + y_argmax = np.squeeze(inputs[0][1], axis=0) + y_pred = [] + for n in range(cfg_unet["num_classes"]): + if cfg_unet["eval_resize"]: + y_pred.append(cv2.resize(np.uint8(y_argmax == n), (w, h), interpolation=cv2.INTER_NEAREST)) + else: + y_pred.append(np.float32(y_argmax == n)) + y_pred = np.stack(y_pred, axis=-1) + else: + raise ValueError('config eval_activate should be softmax or argmax.') + y_pred = y_pred.astype(np.float32) inter = np.dot(y_pred.flatten(), y.flatten()) union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) single_dice_coeff = 2*float(inter)/float(union+1e-6) - print("single dice coeff is:", single_dice_coeff) + single_iou = single_dice_coeff / (2 - single_dice_coeff) + print("single dice coeff is: {}, IOU is: {}".format(single_dice_coeff, single_iou)) self._dice_coeff_sum += single_dice_coeff + self._iou_sum += single_iou def eval(self): if self._samples_num == 0: raise RuntimeError('Total samples num must not be 0.') - - return self._dice_coeff_sum / float(self._samples_num) + return (self._dice_coeff_sum / float(self._samples_num), self._iou_sum / float(self._samples_num)) def test_net(data_dir, cross_valid_ind=1, cfg=None): - _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], - img_size=cfg['img_size']) + if 'dataset' in cfg and cfg['dataset'] == "Cell_nuclei": + valid_dataset = create_cell_nuclei_dataset(data_dir, cfg['img_size'], 1, 1, is_train=False, + eval_resize=cfg["eval_resize"], split=0.8) + else: + _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], + img_size=cfg['img_size']) labels_list = [] for data in valid_dataset: @@ -89,10 +113,25 @@ if __name__ == '__main__': rst_path = args.rst_path metrics = dice_coeff() - for j in range(len(os.listdir(rst_path))): - file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" - output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576) - label = label_list[j] - metrics.update(output, label) - - print("Cross valid dice coeff is: ", metrics.eval()) + if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei": + for i, bin_name in enumerate(os.listdir('./preprocess_Result/')): + bin_name_softmax = bin_name.replace(".png", "") + "_0.bin" + bin_name_argmax = bin_name.replace(".png", "") + "_1.bin" + file_name_sof = rst_path + bin_name_softmax + file_name_arg = rst_path + bin_name_argmax + softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 96, 96, 2) + argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 96, 96) + label = label_list[i] + metrics.update((softmax_out, argmax_out), label) + else: + for j in range(len(os.listdir('./preprocess_Result/'))): + file_name_sof = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" + file_name_arg = rst_path + "ISBI_test_bs_1_" + str(j) + "_1" + ".bin" + softmax_out = np.fromfile(file_name_sof, np.float32).reshape(1, 576, 576, 2) + argmax_out = np.fromfile(file_name_arg, np.float32).reshape(1, 576, 576) + label = label_list[j] + metrics.update((softmax_out, argmax_out), label) + + eval_score = metrics.eval() + print("============== Cross valid dice coeff is:", eval_score[0]) + print("============== Cross valid IOU is:", eval_score[1]) diff --git a/model_zoo/official/cv/unet/preprocess.py b/model_zoo/official/cv/unet/preprocess.py index 9a87b5bd07..28f6cdc623 100644 --- a/model_zoo/official/cv/unet/preprocess.py +++ b/model_zoo/official/cv/unet/preprocess.py @@ -14,6 +14,10 @@ # ============================================================================ """unet 310 infer preprocess dataset""" import argparse +import os +import numpy as np +import cv2 + from src.data_loader import create_dataset from src.config import cfg_unet @@ -29,6 +33,56 @@ def preprocess_dataset(data_dir, result_path, cross_valid_ind=1, cfg=None): data[0].asnumpy().tofile(file_path) +class CellNucleiDataset: + """ + Cell nuclei dataset preprocess class. + """ + def __init__(self, data_dir, repeat, result_path, is_train=False, split=0.8): + self.data_dir = data_dir + self.img_ids = sorted(next(os.walk(self.data_dir))[1]) + self.train_ids = self.img_ids[:int(len(self.img_ids) * split)] * repeat + np.random.shuffle(self.train_ids) + self.val_ids = self.img_ids[int(len(self.img_ids) * split):] + self.is_train = is_train + self.result_path = result_path + self._preprocess_dataset() + + def _preprocess_dataset(self): + for img_id in self.val_ids: + path = os.path.join(self.data_dir, img_id) + img = cv2.imread(os.path.join(path, "images", img_id + ".png")) + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + img = np.concatenate([img, img, img], axis=-1) + mask = [] + for mask_file in next(os.walk(os.path.join(path, "masks")))[2]: + mask_ = cv2.imread(os.path.join(path, "masks", mask_file), cv2.IMREAD_GRAYSCALE) + mask.append(mask_) + mask = np.max(mask, axis=0) + cv2.imwrite(os.path.join(self.result_path, img_id + ".png"), img) + + def _read_img_mask(self, img_id): + path = os.path.join(self.data_dir, img_id) + img = cv2.imread(os.path.join(path, "image.png")) + mask = cv2.imread(os.path.join(path, "mask.png"), cv2.IMREAD_GRAYSCALE) + return img, mask + + def __getitem__(self, index): + if self.is_train: + return self._read_img_mask(self.train_ids[index]) + return self._read_img_mask(self.val_ids[index]) + + @property + def column_names(self): + column_names = ['image', 'mask'] + return column_names + + def __len__(self): + if self.is_train: + return len(self.train_ids) + return len(self.val_ids) + + def get_args(): parser = argparse.ArgumentParser(description='Preprocess the UNet dataset ', formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -42,5 +96,8 @@ def get_args(): if __name__ == '__main__': args = get_args() - preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, result_path= - args.result_path) + if 'dataset' in cfg_unet and cfg_unet['dataset'] == "Cell_nuclei": + cell_dataset = CellNucleiDataset(args.data_url, 1, args.result_path, False, 0.8) + else: + preprocess_dataset(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet, + result_path=args.result_path) diff --git a/model_zoo/official/cv/unet/scripts/run_infer_310.sh b/model_zoo/official/cv/unet/scripts/run_infer_310.sh index 2e60fd5c07..3df2bfcc12 100644 --- a/model_zoo/official/cv/unet/scripts/run_infer_310.sh +++ b/model_zoo/official/cv/unet/scripts/run_infer_310.sh @@ -14,9 +14,10 @@ # limitations under the License. # ============================================================================ -if [[ $# -lt 2 || $# -gt 3 ]]; then - echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] - DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero" +if [[ $# -lt 3 || $# -gt 4 ]]; then + echo "Usage: bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [DEVICE_ID] [NEED_PREPROCESS] + DEVICE_ID is optional, it can be set by environment variable device_id, otherwise the value is zero. + NEED_PREPROCESS means weather need preprocess or not, it's value is 'y' or 'n'." exit 1 fi @@ -29,7 +30,7 @@ get_real_path(){ } model=$(get_real_path $1) data_path=$(get_real_path $2) -if [ $# == 3 ]; then +if [ $# == 4 ]; then device_id=$3 if [ -z $device_id ]; then device_id=0 @@ -37,10 +38,12 @@ if [ $# == 3 ]; then device_id=$device_id fi fi +need_preprocess=$4 echo "mindir name: "$model echo "dataset path: "$data_path echo "device id: "$device_id +echo "need preprocess or not: "$need_preprocess export ASCEND_HOME=/usr/local/Ascend/ if [ -d ${ASCEND_HOME}/ascend-toolkit ]; then @@ -85,7 +88,7 @@ function infer() fi mkdir result_Files mkdir time_Result - ../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id &> infer.log + ../ascend310_infer/src/main --mindir_path=$model --dataset_path=./preprocess_Result/ --device_id=$device_id --need_preprocess=$need_preprocess &> infer.log } function cal_acc()