You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
299 lines
11 KiB
299 lines
11 KiB
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""Quantization utils."""
|
|
|
|
import numpy as np
|
|
|
|
|
|
def cal_quantization_params(input_min,
                            input_max,
                            data_type,
                            num_bits=8,
                            symmetric=False,
                            narrow_range=False):
    r"""
    Calculate quantization params for scale and zero point.

    The float range is widened so that it always contains zero, then mapped
    linearly onto the integer range implied by `data_type`, `num_bits` and
    `narrow_range`. Channels whose widened range is empty (min == max == 0)
    get scale 1 and zero point 0, so that callers never divide by a zero scale.

    Args:
        input_min (numpy.ndarray): The dimension of channel or 1.
        input_max (numpy.ndarray): The dimension of channel or 1.
        data_type (numpy type) : Can be numpy int8, numpy uint8.
        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
        symmetric (bool): Whether the quantization algorithm is symmetric or not. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.

    Raises:
        ValueError: If the shapes of `input_min` and `input_max` differ, if they have
            more than one dimension, if they are empty, if any element of `input_min`
            is greater than the matching element of `input_max`, or if `data_type`
            is neither numpy int8 nor numpy uint8.
    """
    input_min = np.asarray(input_min)
    input_max = np.asarray(input_max)

    if input_min.shape != input_max.shape:
        raise ValueError("input min shape should equal to input max.")
    if len(input_min.shape) > 1:
        raise ValueError("input min and max shape should be one dim.")
    # The bounds must be validated before they are widened to contain zero: afterwards
    # min <= 0 <= max holds for every element and an inverted range can no longer be seen.
    # Empty inputs have always been rejected by this check, so keep rejecting them.
    if input_min.size == 0 or (input_min > input_max).any():
        raise ValueError("input_min min should less than input max.")

    # Widen the range so that the real value 0.0 is always representable.
    input_max = np.maximum(0.0, input_max)
    input_min = np.minimum(0.0, input_min)

    if (input_max == input_min).all():
        return np.ones(input_min.shape), np.zeros(input_min.shape)

    if data_type == np.int8:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    elif data_type == np.uint8:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    else:
        raise ValueError("Unsupported datatype({})".format(data_type))
    if narrow_range:
        quant_min = quant_min + 1

    # calculate scale
    if symmetric:
        input_max = np.maximum(-input_min, input_max)
        input_min = -input_max
    scale = (input_max - input_min) / (quant_max - quant_min)

    # Some (but not all) channels may still have an empty range; give them the same
    # neutral parameters as the all-empty case instead of a zero scale.
    degenerate = scale == 0
    has_degenerate = degenerate.any()
    if has_degenerate:
        scale = np.where(degenerate, np.ones_like(scale), scale)

    # calculate zero point
    if symmetric:
        zp = np.zeros(input_min.shape)
    else:
        zp_double = quant_min - input_min / scale
        zp = np.floor(zp_double + 0.5)
        if has_degenerate:
            zp = np.where(degenerate, np.zeros_like(zp), zp)

    return scale, zp
|
|
|
|
|
|
def weight2int(data, scale, zero_point, data_type, num_bits=8, narrow_range=False):
    r"""
    Calculate int8/uint8 weight from fp32. the formula is defined as:

    .. math::
        int8/uint8 = clip(round(float / scale + offset), quant\_min, quant\_max)

    The result keeps a floating point dtype; only its values are restricted to the
    integer range implied by `data_type`, `num_bits` and `narrow_range`.

    Args:
        data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
        scale (numpy.ndarray): The dimension of channel or 1.
        zero_point (numpy.ndarray): The dimension of channel or 1.
        data_type (numpy type) : Can be numpy int8, numpy uint8.
        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
        narrow_range (bool): Whether the quantization algorithm uses narrow range or not. Default: False.

    Returns:
        weight (numpy.ndarray): The dimension of channel or 1.

    Raises:
        ValueError: If `scale` and `zero_point` differ in shape, if they are empty,
            if a per-channel `scale` matches neither axis 0 nor axis 1 of `data`,
            or if `data_type` is neither numpy int8 nor numpy uint8.
    """
    if scale.shape != zero_point.shape:
        raise ValueError("`scale` and `zero_point` should have the same shape.")
    # A dimension can never be negative, so test for "no elements" explicitly.
    # A 0-d scale is not empty and is handled below as a per-layer parameter.
    if scale.size == 0:
        raise ValueError("`scale` and `zero_point` shape should greater than zero.")
    if len(scale.shape) >= 1 and scale.shape[0] > 1:
        # for perchannel
        if scale.shape[0] == data.shape[0]:
            # `Conv2d` or `Dense` op weight
            shape_list = [-1] + [1] * len(data.shape[1:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        elif scale.shape[0] == data.shape[1]:
            # `DepthwiseConv2d` op weight
            shape_list = [1, -1] + [1] * len(data.shape[2:])
            scale = scale.reshape(shape_list)
            zero_point = zero_point.reshape(shape_list)
        else:
            raise ValueError("Unsupported weight shape({})".format(data.shape))

    if data_type == np.int8:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    elif data_type == np.uint8:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    else:
        raise ValueError("Unsupported weight datatype({})".format(data_type))
    if narrow_range:
        quant_min = quant_min + 1

    weight_int = np.round((data / scale) + zero_point)
    # Saturate values that fall outside the representable integer range.
    return np.clip(weight_int, quant_min, quant_max)
|
|
|
|
def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
    """
    Compute scale and zero point from a `FakeQuantWithMinMax` cell.

    The min/max arrays are read from `cell.minq` / `cell.maxq` and the quantization
    settings (`num_bits`, `symmetric`, `narrow_range`) from `cell.fake_quant_infer`.

    Returns:
        tuple, `(scale, zero point, max, min)` where max and min are the arrays read from the cell.
    """
    min_array = cell.minq.data.asnumpy()
    max_array = cell.maxq.data.asnumpy()
    fake_quant_op = cell.fake_quant_infer

    quant_params = cal_quantization_params(min_array,
                                           max_array,
                                           data_type,
                                           num_bits=fake_quant_op.num_bits,
                                           symmetric=fake_quant_op.symmetric,
                                           narrow_range=fake_quant_op.narrow_range)
    scale, zero_point = quant_params
    return scale, zero_point, max_array, min_array
|
|
|
|
|
|
def scale_zp_from_data(op, minq, maxq, data_type):
    r"""
    Compute the quantization scale and zero point from min/max Parameters.

    The arrays held by `minq` and `maxq` are combined with the quantization
    settings (`num_bits`, `symmetric`, `narrow_range`) carried by `op`.

    Args:
        op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
            `mindspore.ops.operation.FakeQuantPerChannel`
        minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`
        maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`
        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    min_array = minq.data.asnumpy()
    max_array = maxq.data.asnumpy()

    return cal_quantization_params(min_array,
                                   max_array,
                                   data_type,
                                   num_bits=op.num_bits,
                                   symmetric=op.symmetric,
                                   narrow_range=op.narrow_range)
|
|
|
|
|
|
def scale_zp_max_min_from_data(op, minq, maxq, data_type):
    """
    Compute scale and zero point from min/max Parameters and return them with the max and min arrays.

    Returns:
        tuple, `(scale, zero point, max, min)` where max and min are the arrays read from `maxq` and `minq`.
    """
    min_array = minq.data.asnumpy()
    max_array = maxq.data.asnumpy()

    quant_params = cal_quantization_params(min_array,
                                           max_array,
                                           data_type,
                                           num_bits=op.num_bits,
                                           symmetric=op.symmetric,
                                           narrow_range=op.narrow_range)
    scale, zero_point = quant_params
    return scale, zero_point, max_array, min_array
|
|
|
|
|
|
def fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm statistics held by a `Conv2dBnFoldQuant` cell into its weight.

    With :math:`\sigma = \sqrt{variance + eps}` the result is
    :math:`weight \cdot \gamma / \sigma` and :math:`\beta - \gamma \cdot mean / \sigma`.
    `gamma` and `sigma` are broadcast along axis 0 of `weight` when their length matches
    it, otherwise along axis 1.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnFoldQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.

    Raises:
        ValueError: If the length of `gamma` matches neither axis 0 nor axis 1 of `weight`.
    """
    variance = cell_quant.moving_variance.data.asnumpy()
    mean = cell_quant.moving_mean.data.asnumpy()
    gamma = cell_quant.gamma.data.asnumpy()
    beta = cell_quant.beta.data.asnumpy()
    sigma = np.sqrt(variance + cell_quant.eps)

    channels = gamma.shape[0]
    rank = len(weight.shape)
    if channels == weight.shape[0]:
        # `Conv2d` or `Dense` op weight: channel axis is the leading one
        broadcast_shape = (-1,) + (1,) * (rank - 1)
    elif channels == weight.shape[1]:
        # `DepthwiseConv2d` op weight: channel axis is the second one
        broadcast_shape = (1, -1) + (1,) * (rank - 2)
    else:
        raise ValueError("Unsupported weight shape({})".format(weight.shape))

    folded_weight = weight * gamma.reshape(broadcast_shape) / sigma.reshape(broadcast_shape)
    folded_bias = beta - gamma * mean / sigma
    return folded_weight, folded_bias
|
|
|
|
|
|
def without_fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm sub-cell of a `Conv2dBnWithoutFoldQuant` cell into its weight.

    The statistics are read from `cell_quant.batchnorm`. With
    :math:`\sigma = \sqrt{variance + eps}` the result is
    :math:`weight \cdot \gamma / \sigma` and :math:`\beta - \gamma \cdot mean / \sigma`.
    `gamma` and `sigma` are broadcast along axis 0 of `weight` when their length matches
    it, otherwise along axis 1.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBnWithoutFoldQuant`.

    Returns:
        weight (numpy.ndarray): Weight with the batchnorm folded in.
        bias (numpy.ndarray): Bias with the batchnorm folded in.

    Raises:
        ValueError: If the length of `gamma` matches neither axis 0 nor axis 1 of `weight`.
    """
    variance = cell_quant.batchnorm.moving_variance.data.asnumpy()
    mean = cell_quant.batchnorm.moving_mean.data.asnumpy()
    gamma = cell_quant.batchnorm.gamma.data.asnumpy()
    beta = cell_quant.batchnorm.beta.data.asnumpy()
    sigma = np.sqrt(variance + cell_quant.batchnorm.eps)

    channels = gamma.shape[0]
    rank = len(weight.shape)
    if channels == weight.shape[0]:
        # `Conv2d` or `Dense` op weight: channel axis is the leading one
        broadcast_shape = (-1,) + (1,) * (rank - 1)
    elif channels == weight.shape[1]:
        # `DepthwiseConv2d` op weight: channel axis is the second one
        broadcast_shape = (1, -1) + (1,) * (rank - 2)
    else:
        raise ValueError("Unsupported weight shape({})".format(weight.shape))

    folded_weight = weight * gamma.reshape(broadcast_shape) / sigma.reshape(broadcast_shape)
    folded_bias = beta - gamma * mean / sigma
    return folded_weight, folded_bias
|
|
|
|
|
|
def load_nonquant_param_into_quant_net(quant_model, params_dict, quant_new_params=None):
    """
    Load fp32 model parameters into a quantization model.

    Checkpoint entries are grouped by the suffix of their name and handed out in
    checkpoint order: the i-th model parameter whose last dotted name component is
    e.g. `weight` receives the i-th checkpoint entry whose name ends with `weight`.
    When a group runs out of checkpoint entries the remaining model parameters of
    that kind are left untouched.

    Args:
        quant_model: quantization model.
        params_dict: f32 param.
        quant_new_params: parameters that exist in quantative network but not in unquantative network.

    Returns:
        None

    Raises:
        ValueError: If a model parameter's last name component is neither one of the
            known suffixes nor listed in `quant_new_params`.
    """
    known_suffixes = ('weight', 'bias', 'gamma', 'beta',
                      'moving_mean', 'moving_variance', 'minq', 'maxq')
    # One independent iterator per suffix, each over a snapshot of the matching checkpoint items.
    pending = {
        suffix: iter([entry for entry in params_dict.items() if entry[0].endswith(suffix)])
        for suffix in known_suffixes
    }

    for name, param in quant_model.parameters_and_names():
        key_name = name.rpartition(".")[2]
        if key_name in pending:
            matched = next(pending[key_name], None)
            if matched is None:
                # checkpoint has no more entries of this kind
                continue
            param.set_data(matched[1].data)
            print(f'init model param {name} with checkpoint param {matched[0]}')
            continue
        if quant_new_params is not None and key_name in quant_new_params:
            continue
        raise ValueError(f"Can't find match parameter in ckpt,param name = {name}")
|