|
|
|
@ -21,6 +21,7 @@ import cv2
|
|
|
|
|
from mindspore import Tensor
|
|
|
|
|
import mindspore.common.dtype as mstype
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
import mindspore.ops as ops
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
from src.nets import net_factory
|
|
|
|
@ -47,6 +48,8 @@ def parse_args():
|
|
|
|
|
parser.add_argument('--model', type=str, default='deeplab_v3_s16', help='select model')
|
|
|
|
|
parser.add_argument('--freeze_bn', action='store_true', default=False, help='freeze bn')
|
|
|
|
|
parser.add_argument('--ckpt_path', type=str, default='', help='model to evaluate')
|
|
|
|
|
parser.add_argument("--input_format", type=str, choices=["NCHW", "NHWC"], default="NCHW",
|
|
|
|
|
help="NCHW or NHWC")
|
|
|
|
|
|
|
|
|
|
args, _ = parser.parse_known_args()
|
|
|
|
|
return args
|
|
|
|
@ -70,12 +73,16 @@ def resize_long(img, long_size=513):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BuildEvalNetwork(nn.Cell):
|
|
|
|
|
def __init__(self, network):
|
|
|
|
|
def __init__(self, network, input_format="NCHW"):
|
|
|
|
|
super(BuildEvalNetwork, self).__init__()
|
|
|
|
|
self.network = network
|
|
|
|
|
self.softmax = nn.Softmax(axis=1)
|
|
|
|
|
self.transpose = ops.Transpose()
|
|
|
|
|
self.format = input_format
|
|
|
|
|
|
|
|
|
|
def construct(self, input_data):
|
|
|
|
|
if self.format == "NHWC":
|
|
|
|
|
input_data = self.transpose(input_data, (0, 3, 1, 2))
|
|
|
|
|
output = self.network(input_data)
|
|
|
|
|
output = self.softmax(output)
|
|
|
|
|
return output
|
|
|
|
@ -96,7 +103,6 @@ def pre_process(args, img_, crop_size=513):
|
|
|
|
|
pad_w = crop_size - img_.shape[1]
|
|
|
|
|
if pad_h > 0 or pad_w > 0:
|
|
|
|
|
img_ = cv2.copyMakeBorder(img_, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
|
|
|
|
|
|
|
|
|
|
# hwc to chw
|
|
|
|
|
img_ = img_.transpose((2, 0, 1))
|
|
|
|
|
return img_, resize_h, resize_w
|
|
|
|
@ -162,7 +168,7 @@ def net_eval():
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError('model [{:s}] not recognized'.format(args.model))
|
|
|
|
|
|
|
|
|
|
eval_net = BuildEvalNetwork(network)
|
|
|
|
|
eval_net = BuildEvalNetwork(network, args.input_format)
|
|
|
|
|
|
|
|
|
|
# load model
|
|
|
|
|
param_dict = load_checkpoint(args.ckpt_path)
|
|
|
|
|