!5444 Support manual convert to quantative network of resnet

Merge pull request !5444 from chenfei_mindspore/add-manual-quant-network-of-resnet
pull/5444/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 135cfc6adf

@ -252,13 +252,14 @@ def without_fold_batchnorm(weight, cell_quant):
return weight, bias
def load_nonquant_param_into_quant_net(quant_model, params_dict):
def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
"""
load fp32 model parameters to quantization model.
Args:
quant_model: quantization model
params_dict: f32 param
quant_model: quantization model.
params_dict: f32 param.
quant_new_params:parameters that exist in quantative network but not in unquantative network.
Returns:
None
@ -277,6 +278,8 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict):
for name, param in quant_model.parameters_and_names():
key_name = name.split(".")[-1]
if key_name not in iterable_dict.keys():
if quant_new_params is not None and key_name in quant_new_params:
continue
raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
value_param = next(iterable_dict[key_name], None)
if value_param is not None:

@ -20,7 +20,8 @@ import argparse
from src.config import config_quant
from src.dataset import create_dataset
from src.crossentropy import CrossEntropy
from models.resnet_quant import resnet50_quant
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
from mindspore import context
from mindspore.train.model import Model

@ -209,7 +209,7 @@ class ResNet(nn.Cell):
return out
def resnet50_quant(class_num=10001):
def resnet50_quant(class_num=10):
"""
Get ResNet50 neural network.

@ -32,7 +32,8 @@ from mindspore.communication.management import init
import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from models.resnet_quant import resnet50_quant
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
from src.dataset import create_dataset
from src.lr_generator import get_lr
from src.config import config_quant
@ -86,7 +87,7 @@ if __name__ == '__main__':
# weight init and load checkpoint file
if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained)
load_nonquant_param_into_quant_net(net, param_dict)
load_nonquant_param_into_quant_net(net, param_dict, ['step'])
epoch_size = config.epoch_size - config.pretrained_epoch_size
else:
for _, cell in net.cells_and_names():

Loading…
Cancel
Save