From a32c5fbc9245307f9972966bbf92868cfb832f9e Mon Sep 17 00:00:00 2001 From: jzg Date: Tue, 22 Sep 2020 19:51:55 +0800 Subject: [PATCH] amend deeplabv3 hub config --- model_zoo/official/cv/deeplabv3/mindspore_hub_conf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model_zoo/official/cv/deeplabv3/mindspore_hub_conf.py b/model_zoo/official/cv/deeplabv3/mindspore_hub_conf.py index 70e8faec43..9a35473146 100644 --- a/model_zoo/official/cv/deeplabv3/mindspore_hub_conf.py +++ b/model_zoo/official/cv/deeplabv3/mindspore_hub_conf.py @@ -17,11 +17,11 @@ from src.nets import net_factory def create_network(name, *args, **kwargs): freeze_bn = True - num_classes = 21 + num_classes = kwargs["num_classes"] if name == 'deeplab_v3_s16': deeplab_v3_s16_network = net_factory.nets_map["deeplab_v3_s16"]('eval', num_classes, 16, freeze_bn) - return deeplab_v3_s16_network(*args, **kwargs) + return deeplab_v3_s16_network if name == 'deeplab_v3_s8': deeplab_v3_s8_network = net_factory.nets_map["deeplab_v3_s8"]('eval', num_classes, 8, freeze_bn) - return deeplab_v3_s8_network(*args, **kwargs) + return deeplab_v3_s8_network raise NotImplementedError(f"{name} is not implemented in the repo")