|
|
@ -32,7 +32,8 @@ from mindspore.communication.management import init
|
|
|
|
import mindspore.nn as nn
|
|
|
|
import mindspore.nn as nn
|
|
|
|
import mindspore.common.initializer as weight_init
|
|
|
|
import mindspore.common.initializer as weight_init
|
|
|
|
|
|
|
|
|
|
|
|
from models.resnet_quant import resnet50_quant
|
|
|
|
#from models.resnet_quant import resnet50_quant #auto construct quantative network of resnet50
|
|
|
|
|
|
|
|
from models.resnet_quant_manual import resnet50_quant #manually construct quantative network of resnet50
|
|
|
|
from src.dataset import create_dataset
|
|
|
|
from src.dataset import create_dataset
|
|
|
|
from src.lr_generator import get_lr
|
|
|
|
from src.lr_generator import get_lr
|
|
|
|
from src.config import config_quant
|
|
|
|
from src.config import config_quant
|
|
|
@ -86,7 +87,7 @@ if __name__ == '__main__':
|
|
|
|
# weight init and load checkpoint file
|
|
|
|
# weight init and load checkpoint file
|
|
|
|
if args_opt.pre_trained:
|
|
|
|
if args_opt.pre_trained:
|
|
|
|
param_dict = load_checkpoint(args_opt.pre_trained)
|
|
|
|
param_dict = load_checkpoint(args_opt.pre_trained)
|
|
|
|
load_nonquant_param_into_quant_net(net, param_dict)
|
|
|
|
load_nonquant_param_into_quant_net(net, param_dict, ['step'])
|
|
|
|
epoch_size = config.epoch_size - config.pretrained_epoch_size
|
|
|
|
epoch_size = config.epoch_size - config.pretrained_epoch_size
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
for _, cell in net.cells_and_names():
|
|
|
|
for _, cell in net.cells_and_names():
|
|
|
|