!14325 modify repeatedly export quant net mindir

From: @changzherui
Reviewed-by: @zhoufeng54,@zhunaipan
Signed-off-by: @zhunaipan
pull/14325/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9b0e873c33

@ -18,6 +18,8 @@ import sys
import stat import stat
import math import math
import shutil import shutil
import time
import copy
from threading import Thread, Lock from threading import Thread, Lock
import numpy as np import numpy as np
@ -756,6 +758,9 @@ def _quant_export(network, *inputs, file_format, **kwargs):
supported_formats = ['AIR', 'MINDIR'] supported_formats = ['AIR', 'MINDIR']
quant_mode_formats = ['AUTO', 'MANUAL'] quant_mode_formats = ['AUTO', 'MANUAL']
quant_net = copy.deepcopy(network)
quant_net._create_time = int(time.time() * 1e9)
mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean'] mean = 127.5 if kwargs.get('mean', None) is None else kwargs['mean']
std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev'] std_dev = 127.5 if kwargs.get('std_dev', None) is None else kwargs['std_dev']
@ -772,17 +777,17 @@ def _quant_export(network, *inputs, file_format, **kwargs):
if file_format not in supported_formats: if file_format not in supported_formats:
raise ValueError('Illegal file format {}.'.format(file_format)) raise ValueError('Illegal file format {}.'.format(file_format))
network.set_train(False) quant_net.set_train(False)
if file_format == "MINDIR": if file_format == "MINDIR":
if quant_mode == 'MANUAL': if quant_mode == 'MANUAL':
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True) exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
else: else:
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True) exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs, is_mindir=True)
else: else:
if quant_mode == 'MANUAL': if quant_mode == 'MANUAL':
exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs) exporter = quant_export.ExportManualQuantNetwork(quant_net, mean, std_dev, *inputs)
else: else:
exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs) exporter = quant_export.ExportToQuantInferNetwork(quant_net, mean, std_dev, *inputs)
deploy_net = exporter.run() deploy_net = exporter.run()
return deploy_net return deploy_net

Loading…
Cancel
Save