!5444 Support manual convert to quantative network of resnet

Merge pull request !5444 from chenfei_mindspore/add-manual-quant-network-of-resnet
pull/5444/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 135cfc6adf

@ -252,13 +252,14 @@ def without_fold_batchnorm(weight, cell_quant):
return weight, bias return weight, bias
def load_nonquant_param_into_quant_net(quant_model, params_dict): def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
""" """
load fp32 model parameters to quantization model. load fp32 model parameters to quantization model.
Args: Args:
quant_model: quantization model quant_model: quantization model.
params_dict: f32 param params_dict: f32 param.
quant_new_params:parameters that exist in quantative network but not in unquantative network.
Returns: Returns:
None None
@ -277,6 +278,8 @@ def load_nonquant_param_into_quant_net(quant_model, params_dict):
for name, param in quant_model.parameters_and_names(): for name, param in quant_model.parameters_and_names():
key_name = name.split(".")[-1] key_name = name.split(".")[-1]
if key_name not in iterable_dict.keys(): if key_name not in iterable_dict.keys():
if quant_new_params is not None and key_name in quant_new_params:
continue
raise ValueError(f"Can't find match parameter in ckpt,param name = {name}") raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
value_param = next(iterable_dict[key_name], None) value_param = next(iterable_dict[key_name], None)
if value_param is not None: if value_param is not None:

@ -20,7 +20,8 @@ import argparse
from src.config import config_quant from src.config import config_quant
from src.dataset import create_dataset from src.dataset import create_dataset
from src.crossentropy import CrossEntropy from src.crossentropy import CrossEntropy
from models.resnet_quant import resnet50_quant #from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
from mindspore import context from mindspore import context
from mindspore.train.model import Model from mindspore.train.model import Model

@ -209,7 +209,7 @@ class ResNet(nn.Cell):
return out return out
def resnet50_quant(class_num=10001): def resnet50_quant(class_num=10):
""" """
Get ResNet50 neural network. Get ResNet50 neural network.

@ -32,7 +32,8 @@ from mindspore.communication.management import init
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.initializer as weight_init import mindspore.common.initializer as weight_init
from models.resnet_quant import resnet50_quant #from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
from src.dataset import create_dataset from src.dataset import create_dataset
from src.lr_generator import get_lr from src.lr_generator import get_lr
from src.config import config_quant from src.config import config_quant
@ -86,7 +87,7 @@ if __name__ == '__main__':
# weight init and load checkpoint file # weight init and load checkpoint file
if args_opt.pre_trained: if args_opt.pre_trained:
param_dict = load_checkpoint(args_opt.pre_trained) param_dict = load_checkpoint(args_opt.pre_trained)
load_nonquant_param_into_quant_net(net, param_dict) load_nonquant_param_into_quant_net(net, param_dict, ['step'])
epoch_size = config.epoch_size - config.pretrained_epoch_size epoch_size = config.epoch_size - config.pretrained_epoch_size
else: else:
for _, cell in net.cells_and_names(): for _, cell in net.cells_and_names():

Loading…
Cancel
Save