From 888a5412a5bae3d08cee1cd109747dc8fccd21ee Mon Sep 17 00:00:00 2001 From: huzhifeng Date: Mon, 21 Sep 2020 15:46:36 +0800 Subject: [PATCH] modify ghostnet for hub --- .../resnet50_adv_pruning/mindpsore_hub_conf.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/model_zoo/research/cv/resnet50_adv_pruning/mindpsore_hub_conf.py b/model_zoo/research/cv/resnet50_adv_pruning/mindpsore_hub_conf.py index e38b765e6c..9dd975ab93 100644 --- a/model_zoo/research/cv/resnet50_adv_pruning/mindpsore_hub_conf.py +++ b/model_zoo/research/cv/resnet50_adv_pruning/mindpsore_hub_conf.py @@ -14,9 +14,21 @@ # ============================================================================ """hub config.""" from src.resnet_imgnet import resnet50 +from mindspore import Tensor +import numpy as np -def create_network(name, *args, **kwargs): - if name == 'resnet-0.65x': - return resnet50(*args, **kwargs) +def get_index(filename): + index = [] + with open(filename) as fr: + for line in fr: + ind = Tensor((np.array(line.strip('\n').split(' ')[:-1])).astype(np.int32).reshape(-1, 1)) + index.append(ind) + return index + + +def create_network(name, rate=0.65, index_filename='index.txt', **kwargs): + index = get_index(index_filename) + if name == 'resnet50-0.65x': + return resnet50(rate=rate, index=index, **kwargs) raise NotImplementedError(f"{name} is not implemented in the repo")