@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+import os
+import re
 import logging
 import numpy as np
 from ....executor import global_scope
@@ -43,7 +45,9 @@ class PostTrainingQuantization(object):
                  scope=None,
                  algo="KL",
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-                 is_full_quantize=False):
+                 is_full_quantize=False,
+                 is_use_cache_file=False,
+                 cache_dir="./temp_post_training"):
         '''
         The class utilizes the post training quantization method to quantize the
         fp32 model. It uses calibrate data to calculate the scale factor of
@@ -78,9 +82,16 @@ class PostTrainingQuantization(object):
                 that will be quantized. Default is ["conv2d", "depthwise_conv2d",
                 "mul"].
             is_full_quantize(bool, optional): If is_full_quantize is True,
                 apply quantization to all supported quantizable op types. If
                 it is False, only apply quantization to the op types given by
                 the input quantizable_op_type.
+            is_use_cache_file(bool, optional): If is_use_cache_file is False,
+                all temp data will be kept in memory. If it is True, the temp
+                data will be saved to disk. When the fp32 model is complex or
+                the amount of calibrate data is large, is_use_cache_file should
+                be set to True. Default is False.
+            cache_dir(str, optional): When is_use_cache_file is True, set cache_dir
+                as the directory for saving temp data. Default is ./temp_post_training.
         Returns:
             None

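For reference, a minimal usage sketch of the two new arguments (not part of this patch; the `executor`, `sample_generator`, and `model_dir` parameters precede this hunk, and the reader and paths below are hypothetical placeholders):

```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,
    sample_generator=calib_reader,    # hypothetical calibration data reader
    model_dir="./fp32_model",         # hypothetical fp32 model directory
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    is_full_quantize=False,
    is_use_cache_file=True,           # spill per-batch activation data to disk
    cache_dir="./temp_post_training")
ptq.quantize()
```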
@@ -129,6 +140,10 @@ class PostTrainingQuantization(object):
         self._batch_nums = batch_nums
         self._scope = global_scope() if scope == None else scope
         self._algo = algo
+        self._is_use_cache_file = is_use_cache_file
+        self._cache_dir = cache_dir
+        if self._is_use_cache_file and not os.path.exists(self._cache_dir):
+            os.mkdir(self._cache_dir)

         supported_quantizable_op_type = \
             QuantizationTransformPass._supported_quantizable_op_type + \
@@ -150,8 +165,8 @@ class PostTrainingQuantization(object):

         self._op_real_in_out_name = _op_real_in_out_name
         self._bit_length = 8
-        self._quantized_weight_var_name = []
-        self._quantized_act_var_name = []
+        self._quantized_weight_var_name = set()
+        self._quantized_act_var_name = set()
         self._sampling_data = {}
         self._quantized_var_scale_factor = {}
@@ -174,7 +189,8 @@ class PostTrainingQuantization(object):
                                feed=data,
                                fetch_list=self._fetch_list,
                                return_numpy=False)
-            self._sample_data()
+            self._sample_data(batch_id)
+
             if batch_id % 5 == 0:
                 _logger.info("run batch: " + str(batch_id))
             batch_id += 1
@@ -238,10 +254,9 @@ class PostTrainingQuantization(object):
             op_type = op.type
             if op_type in self._quantizable_op_type:
                 if op_type in ("conv2d", "depthwise_conv2d"):
-                    self._quantized_act_var_name.append(op.input("Input")[0])
-                    self._quantized_weight_var_name.append(
-                        op.input("Filter")[0])
-                    self._quantized_act_var_name.append(op.output("Output")[0])
+                    self._quantized_act_var_name.add(op.input("Input")[0])
+                    self._quantized_weight_var_name.add(op.input("Filter")[0])
+                    self._quantized_act_var_name.add(op.output("Output")[0])
                 elif op_type == "mul":
                     if self._is_input_all_not_persistable(
                             op, persistable_var_names):
@@ -249,9 +264,9 @@ class PostTrainingQuantization(object):
                         _logger.warning("Skip quant a mul op for two "
                                         "input variables are not persistable")
                     else:
-                        self._quantized_act_var_name.append(op.input("X")[0])
-                        self._quantized_weight_var_name.append(op.input("Y")[0])
-                        self._quantized_act_var_name.append(op.output("Out")[0])
+                        self._quantized_act_var_name.add(op.input("X")[0])
+                        self._quantized_weight_var_name.add(op.input("Y")[0])
+                        self._quantized_act_var_name.add(op.output("Out")[0])
                 else:
                     # process other quantizable op type, the input must all not persistable
                     if self._is_input_all_not_persistable(
@@ -260,10 +275,10 @@ class PostTrainingQuantization(object):
                             op_type]
                         for input_name in input_output_name_list[0]:
                             for var_name in op.input(input_name):
-                                self._quantized_act_var_name.append(var_name)
+                                self._quantized_act_var_name.add(var_name)
                         for output_name in input_output_name_list[1]:
                             for var_name in op.output(output_name):
-                                self._quantized_act_var_name.append(var_name)
+                                self._quantized_act_var_name.add(var_name)

         # set activation variables to be persistable, so can obtain
         # the tensor data in sample_data
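The `append` → `add` switch in the three hunks above deduplicates variable names that are reached through more than one op. A standalone illustration of the difference:

```python
# With a list, an activation consumed by two quantizable ops is recorded
# twice and would be sampled twice; a set records it exactly once.
quantized_act_list = []
quantized_act_set = set()
for var_name in ["conv1_out", "conv1_out", "fc1_out"]:
    quantized_act_list.append(var_name)
    quantized_act_set.add(var_name)

print(len(quantized_act_list))  # 3 -> duplicated sampling work
print(len(quantized_act_set))   # 2 -> each variable sampled once
```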
@@ -271,7 +286,7 @@ class PostTrainingQuantization(object):
             if var.name in self._quantized_act_var_name:
                 var.persistable = True

-    def _sample_data(self):
+    def _sample_data(self, iter):
         '''
         Sample the tensor data of quantized variables,
         applied in every iteration.
@@ -281,11 +296,20 @@ class PostTrainingQuantization(object):
                 var_tensor = self._load_var_value(var_name)
                 self._sampling_data[var_name] = var_tensor

-        for var_name in self._quantized_act_var_name:
-            if var_name not in self._sampling_data:
-                self._sampling_data[var_name] = []
-            var_tensor = self._load_var_value(var_name)
-            self._sampling_data[var_name].append(var_tensor)
+        if self._is_use_cache_file:
+            for var_name in self._quantized_act_var_name:
+                var_tensor = self._load_var_value(var_name)
+                var_tensor = var_tensor.ravel()
+                save_path = os.path.join(self._cache_dir,
+                                         var_name + "_" + str(iter) + ".npy")
+                np.save(save_path, var_tensor)
+        else:
+            for var_name in self._quantized_act_var_name:
+                if var_name not in self._sampling_data:
+                    self._sampling_data[var_name] = []
+                var_tensor = self._load_var_value(var_name)
+                var_tensor = var_tensor.ravel()
+                self._sampling_data[var_name].append(var_tensor)

     def _calculate_scale_factor(self):
         '''
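A self-contained sketch of the new write path: per calibration batch, each sampled activation is flattened and saved as `<var_name>_<iter>.npy` under `cache_dir` (pure numpy; the variable name, shapes, and batch count below are hypothetical):

```python
import os
import numpy as np

cache_dir = "./temp_post_training"  # the default cache_dir
if not os.path.exists(cache_dir):
    os.mkdir(cache_dir)

for it in range(3):                      # stand-in for the calibration loop
    var_tensor = np.random.randn(2, 4)   # stand-in for _load_var_value(var_name)
    save_path = os.path.join(cache_dir, "conv1_out_" + str(it) + ".npy")
    np.save(save_path, var_tensor.ravel())
```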
@@ -302,13 +326,33 @@ class PostTrainingQuantization(object):
                         var_name] = scale_factor_per_channel

         # apply kl quantization for activation
-        for var_name in self._quantized_act_var_name:
-            if self._algo == "KL":
-                self._quantized_var_scale_factor[var_name] = \
-                    self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
-            else:
-                self._quantized_var_scale_factor[var_name] = \
-                    np.max(np.abs(self._sampling_data[var_name]))
+        if self._is_use_cache_file:
+            for var_name in self._quantized_act_var_name:
+                sampling_data = []
+                filenames = [f for f in os.listdir(self._cache_dir) \
+                             if re.match(var_name + '_[0-9]+.npy', f)]
+                for filename in filenames:
+                    file_path = os.path.join(self._cache_dir, filename)
+                    sampling_data.append(np.load(file_path))
+                    os.remove(file_path)
+                sampling_data = np.concatenate(sampling_data)
+
+                if self._algo == "KL":
+                    self._quantized_var_scale_factor[var_name] = \
+                        self._get_kl_scaling_factor(np.abs(sampling_data))
+                else:
+                    self._quantized_var_scale_factor[var_name] = \
+                        np.max(np.abs(sampling_data))
+        else:
+            for var_name in self._quantized_act_var_name:
+                self._sampling_data[var_name] = np.concatenate(
+                    self._sampling_data[var_name])
+                if self._algo == "KL":
+                    self._quantized_var_scale_factor[var_name] = \
+                        self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
+                else:
+                    self._quantized_var_scale_factor[var_name] = \
+                        np.max(np.abs(self._sampling_data[var_name]))

     def _update_program(self):
         '''
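And the matching read path: `_calculate_scale_factor` gathers every cached batch for a variable, deletes the files, concatenates, and reduces to a single scale. A sketch of the abs-max branch (run after the write sketch above so the cache files exist; `_get_kl_scaling_factor` is internal to the class and omitted here):

```python
import os
import re
import numpy as np

cache_dir = "./temp_post_training"
var_name = "conv1_out"  # hypothetical activation name

sampling_data = []
for filename in os.listdir(cache_dir):
    if re.match(var_name + '_[0-9]+.npy', filename):
        file_path = os.path.join(cache_dir, filename)
        sampling_data.append(np.load(file_path))
        os.remove(file_path)  # the cache is consumed once

sampling_data = np.concatenate(sampling_data)
scale = np.max(np.abs(sampling_data))  # the non-KL branch above
print(scale)
```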