!6636 fix mobilenetv2 script error

Merge pull request !6636 from zhaoting/hub
pull/6636/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0037afee74

@ -28,7 +28,7 @@ from src.utils import switch_precision, set_context
if __name__ == '__main__':
args_opt = eval_parse_args()
config = set_config(args_opt)
backbone_net, head_net, net = define_net(config)
backbone_net, head_net, net = define_net(config, args_opt.is_training)
#load the trained checkpoint file to the net for evaluation
if args_opt.head_ckpt:

@ -119,9 +119,11 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True):
for param in network.get_parameters():
param.requires_grad = False
def define_net(config):
def define_net(config, is_training):
backbone_net = MobileNetV2Backbone()
activation = config.activation if not args.is_training else "None"
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
net = mobilenet_v2(backbone_net, head_net, activation=activation)
activation = config.activation if not is_training else "None"
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,
num_classes=config.num_classes,
activation=activation)
net = mobilenet_v2(backbone_net, head_net)
return backbone_net, head_net, net

@ -51,7 +51,7 @@ if __name__ == '__main__':
context_device_init(config)
# define network
backbone_net, head_net, net = define_net(config)
backbone_net, head_net, net = define_net(config, args_opt.is_training)
# load the ckpt file to the network for fine tune or incremental leaning
if args_opt.pretrain_ckpt:

Loading…
Cancel
Save