# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ """unet 310 infer.""" import os import argparse import numpy as np from src.data_loader import create_dataset from src.config import cfg_unet from scipy.special import softmax class dice_coeff(): def __init__(self): self.clear() def clear(self): self._dice_coeff_sum = 0 self._samples_num = 0 def update(self, *inputs): if len(inputs) != 2: raise ValueError('Mean dice coefficient need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) y_pred = inputs[0] y = np.array(inputs[1]) self._samples_num += y.shape[0] y_pred = y_pred.transpose(0, 2, 3, 1) y = y.transpose(0, 2, 3, 1) y_pred = softmax(y_pred, axis=3) inter = np.dot(y_pred.flatten(), y.flatten()) union = np.dot(y_pred.flatten(), y_pred.flatten()) + np.dot(y.flatten(), y.flatten()) single_dice_coeff = 2*float(inter)/float(union+1e-6) print("single dice coeff is:", single_dice_coeff) self._dice_coeff_sum += single_dice_coeff def eval(self): if self._samples_num == 0: raise RuntimeError('Total samples num must not be 0.') return self._dice_coeff_sum / float(self._samples_num) def test_net(data_dir, cross_valid_ind=1, cfg=None): _, valid_dataset = create_dataset(data_dir, 1, 1, False, cross_valid_ind, False, do_crop=cfg['crop'], img_size=cfg['img_size']) labels_list = [] for data in valid_dataset: labels_list.append(data[1].asnumpy()) return labels_list def get_args(): parser = argparse.ArgumentParser(description='Test the UNet on images and target masks', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-d', '--data_url', dest='data_url', type=str, default='data/', help='data directory') parser.add_argument('-p', '--rst_path', dest='rst_path', type=str, default='./result_Files/', help='infer result path') return parser.parse_args() if __name__ == '__main__': args = get_args() label_list = test_net(data_dir=args.data_url, cross_valid_ind=cfg_unet['cross_valid_ind'], cfg=cfg_unet) rst_path = args.rst_path metrics = dice_coeff() for j in range(len(os.listdir(rst_path))): file_name = rst_path + "ISBI_test_bs_1_" + str(j) + "_0" + ".bin" output = np.fromfile(file_name, np.float32).reshape(1, 2, 576, 576) label = label_list[j] metrics.update(output, label) print("Cross valid dice coeff is: ", metrics.eval())