From a7e881f312184090cd8e483b57b22b7f0d518203 Mon Sep 17 00:00:00 2001 From: zhaoting Date: Mon, 21 Sep 2020 15:44:31 +0800 Subject: [PATCH] fix mobilenetv2 script error --- model_zoo/official/cv/mobilenetv2/eval.py | 2 +- model_zoo/official/cv/mobilenetv2/src/models.py | 10 ++++++---- model_zoo/official/cv/mobilenetv2/train.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/model_zoo/official/cv/mobilenetv2/eval.py b/model_zoo/official/cv/mobilenetv2/eval.py index a20cdbef1f..a23b0f02ce 100644 --- a/model_zoo/official/cv/mobilenetv2/eval.py +++ b/model_zoo/official/cv/mobilenetv2/eval.py @@ -28,7 +28,7 @@ from src.utils import switch_precision, set_context if __name__ == '__main__': args_opt = eval_parse_args() config = set_config(args_opt) - backbone_net, head_net, net = define_net(config) + backbone_net, head_net, net = define_net(config, args_opt.is_training) #load the trained checkpoint file to the net for evaluation if args_opt.head_ckpt: diff --git a/model_zoo/official/cv/mobilenetv2/src/models.py b/model_zoo/official/cv/mobilenetv2/src/models.py index 48b1f5b2c2..4b391adbe5 100644 --- a/model_zoo/official/cv/mobilenetv2/src/models.py +++ b/model_zoo/official/cv/mobilenetv2/src/models.py @@ -119,9 +119,11 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True): for param in network.get_parameters(): param.requires_grad = False -def define_net(config): +def define_net(config, is_training): backbone_net = MobileNetV2Backbone() - activation = config.activation if not args.is_training else "None" - head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes) - net = mobilenet_v2(backbone_net, head_net, activation=activation) + activation = config.activation if not is_training else "None" + head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, + num_classes=config.num_classes, + activation=activation) + net = mobilenet_v2(backbone_net, head_net) return backbone_net, head_net, net diff --git a/model_zoo/official/cv/mobilenetv2/train.py b/model_zoo/official/cv/mobilenetv2/train.py index 3c5e9878f3..01a629dc0d 100644 --- a/model_zoo/official/cv/mobilenetv2/train.py +++ b/model_zoo/official/cv/mobilenetv2/train.py @@ -51,7 +51,7 @@ if __name__ == '__main__': context_device_init(config) # define network - backbone_net, head_net, net = define_net(config) + backbone_net, head_net, net = define_net(config, args_opt.is_training) # load the ckpt file to the network for fine tune or incremental leaning if args_opt.pretrain_ckpt: