|
|
|
@ -119,9 +119,11 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True):
|
|
|
|
|
for param in network.get_parameters():
|
|
|
|
|
param.requires_grad = False
|
|
|
|
|
|
|
|
|
|
def define_net(config):
|
|
|
|
|
def define_net(config, is_training):
|
|
|
|
|
backbone_net = MobileNetV2Backbone()
|
|
|
|
|
activation = config.activation if not args.is_training else "None"
|
|
|
|
|
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, num_classes=config.num_classes)
|
|
|
|
|
net = mobilenet_v2(backbone_net, head_net, activation=activation)
|
|
|
|
|
activation = config.activation if not is_training else "None"
|
|
|
|
|
head_net = MobileNetV2Head(input_channel=backbone_net.out_channels,
|
|
|
|
|
num_classes=config.num_classes,
|
|
|
|
|
activation=activation)
|
|
|
|
|
net = mobilenet_v2(backbone_net, head_net)
|
|
|
|
|
return backbone_net, head_net, net
|
|
|
|
|