From 3e9b97e7e9f4772342d23a5088199d0616249ffb Mon Sep 17 00:00:00 2001 From: zhaoting Date: Mon, 21 Sep 2020 19:20:31 +0800 Subject: [PATCH] add moboilenetv3 and resnext50 hub --- model_zoo/official/cv/mobilenetv2/README.md | 1 + .../official/cv/mobilenetv2/src/models.py | 2 +- model_zoo/official/cv/mobilenetv3/Readme.md | 1 + model_zoo/official/cv/mobilenetv3/eval.py | 2 +- .../cv/mobilenetv3/mindspore_hub_conf.py | 25 +++++++++++++++ .../cv/mobilenetv3/src/mobilenetV3.py | 30 +++++++++++++----- model_zoo/official/cv/resnext50/README.md | 3 +- model_zoo/official/cv/resnext50/eval.py | 2 +- .../cv/resnext50/mindspore_hub_conf.py | 22 +++++++++++++ .../cv/resnext50/src/image_classification.py | 31 ++++++++++++++----- model_zoo/official/cv/resnext50/train.py | 2 +- model_zoo/official/cv/ssd/README.md | 3 +- 12 files changed, 103 insertions(+), 21 deletions(-) create mode 100644 model_zoo/official/cv/mobilenetv3/mindspore_hub_conf.py create mode 100644 model_zoo/official/cv/resnext50/mindspore_hub_conf.py diff --git a/model_zoo/official/cv/mobilenetv2/README.md b/model_zoo/official/cv/mobilenetv2/README.md index 12f1d12d0e..ebbfd87927 100644 --- a/model_zoo/official/cv/mobilenetv2/README.md +++ b/model_zoo/official/cv/mobilenetv2/README.md @@ -77,6 +77,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil │ ├──utils.py # utils to load ckpt_file for fine tune or incremental learn ├── train.py # training script ├── eval.py # evaluation script + ├── mindspore_hub_conf.py # mindspore hub interface ``` ## [Training process](#contents) diff --git a/model_zoo/official/cv/mobilenetv2/src/models.py b/model_zoo/official/cv/mobilenetv2/src/models.py index 4b391adbe5..d74a97f755 100644 --- a/model_zoo/official/cv/mobilenetv2/src/models.py +++ b/model_zoo/official/cv/mobilenetv2/src/models.py @@ -119,7 +119,7 @@ def load_ckpt(network, pretrain_ckpt_path, trainable=True): for param in network.get_parameters(): param.requires_grad = False -def define_net(config, is_training): +def define_net(config, is_training=True): backbone_net = MobileNetV2Backbone() activation = config.activation if not is_training else "None" head_net = MobileNetV2Head(input_channel=backbone_net.out_channels, diff --git a/model_zoo/official/cv/mobilenetv3/Readme.md b/model_zoo/official/cv/mobilenetv3/Readme.md index e4fcc86773..5fb65d2513 100644 --- a/model_zoo/official/cv/mobilenetv3/Readme.md +++ b/model_zoo/official/cv/mobilenetv3/Readme.md @@ -69,6 +69,7 @@ Dataset used: [imagenet](http://www.image-net.org/) │ ├──mobilenetV3.py # MobileNetV3 architecture ├── train.py # training script ├── eval.py # evaluation script + ├── mindspore_hub_conf.py # mindspore hub interface ``` ## [Training process](#contents) diff --git a/model_zoo/official/cv/mobilenetv3/eval.py b/model_zoo/official/cv/mobilenetv3/eval.py index d7e076490f..1babebf1dc 100644 --- a/model_zoo/official/cv/mobilenetv3/eval.py +++ b/model_zoo/official/cv/mobilenetv3/eval.py @@ -42,7 +42,7 @@ if __name__ == '__main__': raise ValueError("Unsupported device_target.") loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') - net = mobilenet_v3_large(num_classes=config.num_classes) + net = mobilenet_v3_large(num_classes=config.num_classes, activation="Softmax") dataset = create_dataset(dataset_path=args_opt.dataset_path, do_train=False, diff --git a/model_zoo/official/cv/mobilenetv3/mindspore_hub_conf.py b/model_zoo/official/cv/mobilenetv3/mindspore_hub_conf.py new file mode 100644 index 0000000000..54eed56e6d --- /dev/null +++ b/model_zoo/official/cv/mobilenetv3/mindspore_hub_conf.py @@ -0,0 +1,25 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config.""" +from src.mobilenetV3 import mobilenet_v3_large, mobilenet_v3_small + +def create_network(name, *args, **kwargs): + if name == "mobilenetv3_large": + net = mobilenet_v3_large(*args, **kwargs) + elif name == "mobilenetv3_small": + net = mobilenet_v3_small(*args, **kwargs) + else: + raise NotImplementedError(f"{name} is not implemented in the repo") + return net diff --git a/model_zoo/official/cv/mobilenetv3/src/mobilenetV3.py b/model_zoo/official/cv/mobilenetv3/src/mobilenetV3.py index 069efa514a..0bf59b9fb1 100644 --- a/model_zoo/official/cv/mobilenetv3/src/mobilenetV3.py +++ b/model_zoo/official/cv/mobilenetv3/src/mobilenetV3.py @@ -246,7 +246,8 @@ class MobileNetV3(nn.Cell): >>> MobileNetV3(num_classes=1000) """ - def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8): + def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., + round_nearest=8, include_top=True, activation="None"): super(MobileNetV3, self).__init__() self.cfgs = model_cfgs['cfg'] self.inplanes = 16 @@ -285,19 +286,34 @@ class MobileNetV3(nn.Cell): # make it nn.CellList self.features = nn.SequentialCell(self.features) - self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'], - out_channels=num_classes, - kernel_size=1, has_bias=True, pad_mode='pad') - self.squeeze = P.Squeeze(axis=(2, 3)) + self.include_top = include_top + self.need_activation = False + if self.include_top: + self.output = nn.Conv2d(in_channels=model_cfgs['cls_ch_expand'], + out_channels=num_classes, + kernel_size=1, has_bias=True, pad_mode='pad') + self.squeeze = P.Squeeze(axis=(2, 3)) + if activation != "None": + self.need_activation = True + if activation == "Sigmoid": + self.activation = P.Sigmoid() + elif activation == "Softmax": + self.activation = P.Softmax() + else: + raise NotImplementedError(f"The activation {activation} not in [Sigmoid, Softmax].") self._initialize_weights() def construct(self, x): x = self.features(x) - x = self.output(x) - x = self.squeeze(x) + if self.include_top: + x = self.output(x) + x = self.squeeze(x) + if self.need_activation: + x = self.activation(x) return x + def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1): mid_planes = exp_ch out_planes = out_channel diff --git a/model_zoo/official/cv/resnext50/README.md b/model_zoo/official/cv/resnext50/README.md index b1ec1fbc6e..eb57e651af 100644 --- a/model_zoo/official/cv/resnext50/README.md +++ b/model_zoo/official/cv/resnext50/README.md @@ -96,7 +96,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil ├─warmup_cosine_annealing.py # learning rate each step ├─warmup_step_lr.py # warmup step learning rate ├─eval.py # eval net - └─train.py # train net + ├──train.py # train net + ├──mindspore_hub_conf.py # mindspore hub interface ``` diff --git a/model_zoo/official/cv/resnext50/eval.py b/model_zoo/official/cv/resnext50/eval.py index 9dafa5070a..7247665c5e 100644 --- a/model_zoo/official/cv/resnext50/eval.py +++ b/model_zoo/official/cv/resnext50/eval.py @@ -201,7 +201,7 @@ def test(cloud_args=None): max_epoch=1, rank=args.rank, group_size=args.group_size, mode='eval') eval_dataloader = de_dataset.create_tuple_iterator(output_numpy=True) - network = get_network(args.backbone, args.num_classes, platform=args.platform) + network = get_network(args.backbone, num_classes=args.num_classes, platform=args.platform) if network is None: raise NotImplementedError('not implement {}'.format(args.backbone)) diff --git a/model_zoo/official/cv/resnext50/mindspore_hub_conf.py b/model_zoo/official/cv/resnext50/mindspore_hub_conf.py new file mode 100644 index 0000000000..b8bd401518 --- /dev/null +++ b/model_zoo/official/cv/resnext50/mindspore_hub_conf.py @@ -0,0 +1,22 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""hub config.""" +from src.image_classification import get_network + +def create_network(name, *args, **kwargs): + if name == "renext50": + get_network("renext50", *args, **kwargs) + return net + raise NotImplementedError(f"{name} is not implemented in the repo") diff --git a/model_zoo/official/cv/resnext50/src/image_classification.py b/model_zoo/official/cv/resnext50/src/image_classification.py index fabc4aeeac..37e17caad8 100644 --- a/model_zoo/official/cv/resnext50/src/image_classification.py +++ b/model_zoo/official/cv/resnext50/src/image_classification.py @@ -31,31 +31,46 @@ class ImageClassificationNetwork(nn.Cell): Returns: Tensor, output tensor. """ - def __init__(self, backbone, head): + def __init__(self, backbone, head, include_top=True, activation="None"): super(ImageClassificationNetwork, self).__init__() self.backbone = backbone - self.head = head + self.include_top = include_top + self.need_activation = False + if self.include_top: + self.head = head + if activation != "None": + self.need_activation = True + if activation == "Sigmoid": + self.activation = P.Sigmoid() + elif activation == "Softmax": + self.activation = P.Softmax() + else: + raise NotImplementedError(f"The activation {activation} not in [Sigmoid, Softmax].") def construct(self, x): x = self.backbone(x) - x = self.head(x) + if self.include_top: + x = self.head(x) + if self.need_activation: + x = self.activation(x) return x + class Resnet(ImageClassificationNetwork): """ Resnet architecture. Args: backbone_name (string): backbone. - num_classes (int): number of classes. + num_classes (int): number of classes, Default is 1000. Returns: Resnet. """ - def __init__(self, backbone_name, num_classes, platform="Ascend"): + def __init__(self, backbone_name, num_classes=1000, platform="Ascend", include_top=True, activation="None"): self.backbone_name = backbone_name backbone = backbones.__dict__[self.backbone_name](platform=platform) out_channels = backbone.get_out_channels() head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels) - super(Resnet, self).__init__(backbone, head) + super(Resnet, self).__init__(backbone, head, include_top, activation) default_recurisive_init(self) @@ -79,7 +94,7 @@ class Resnet(ImageClassificationNetwork): -def get_network(backbone_name, num_classes, platform="Ascend"): +def get_network(backbone_name, **kwargs): if backbone_name in ['resnext50']: - return Resnet(backbone_name, num_classes, platform) + return Resnet(backbone_name, **kwargs) return None diff --git a/model_zoo/official/cv/resnext50/train.py b/model_zoo/official/cv/resnext50/train.py index 983ce37faa..1611887210 100644 --- a/model_zoo/official/cv/resnext50/train.py +++ b/model_zoo/official/cv/resnext50/train.py @@ -213,7 +213,7 @@ def train(cloud_args=None): # network args.logger.important_info('start create network') # get network and init - network = get_network(args.backbone, args.num_classes, platform=args.platform) + network = get_network(args.backbone, num_classes=args.num_classes, platform=args.platform) if network is None: raise NotImplementedError('not implement {}'.format(args.backbone)) diff --git a/model_zoo/official/cv/ssd/README.md b/model_zoo/official/cv/ssd/README.md index f1e40c4e22..2dc5fd19e7 100644 --- a/model_zoo/official/cv/ssd/README.md +++ b/model_zoo/official/cv/ssd/README.md @@ -114,7 +114,8 @@ sh run_eval.sh [DATASET] [CHECKPOINT_PATH] [DEVICE_ID] ├─ lr_schedule.py ## learning ratio generator └─ ssd.py ## ssd architecture ├─ eval.py ## eval scripts - └─ train.py ## train scripts + ├─ train.py ## train scripts + ├── mindspore_hub_conf.py # mindspore hub interface ``` ## [Script Parameters](#contents)