!2161 [qat]Export network from quantization aware network to deploy

Merge pull request !2161 from vlne-v1/I1IZV3-quant-infer
Committed by mindspore-ci-bot via Gitee · commit 9958bc479a

@@ -487,6 +487,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
  }));
  (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
    .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
    .def(py::pickle(
      [](const MetaTensor &t) {  // __getstate__
        /* Return a tuple that fully encodes the state of the object */
        return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
      },
      [](const py::tuple &t) {  // __setstate__
        if (t.size() != 2) {
          throw std::runtime_error("Invalid state!");
        }
        /* Create a new C++ instance */
        MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
        return tensor;
      }))
    .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
    .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
    .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");

@@ -220,6 +220,8 @@ const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
// Other miscellaneous
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");

@@ -228,6 +228,8 @@ extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLike;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;
extern const PrimitivePtr kPrimFakeQuantPerLayer;
extern const PrimitivePtr kPrimFakeQuantPerChannel;
// Other Miscellaneous
extern const PrimitivePtr kPrimIdentity;

@@ -77,6 +77,8 @@ PYBIND11_MODULE(_c_expression, m) {
         "Get CNode Strategy Dictionary.")
    .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
         "Get Allreduce Fusion Dictionary.")
    .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
         "Fetch the inputs of Conv or Matmul for quant export.")
    .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
         py::arg("broadcast_params") = py::dict(), "Build data graph.")
    .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "Get whether the cell has been compiled.")

@@ -281,6 +281,75 @@ ExecutorPy::~ExecutorPy() {
  ConfigManager::GetInstance().ResetConfig();
}

std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
  const std::string &phase_s) {
  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
  // Keep only Conv2D and MatMul nodes in the search results.
  auto filter = [](AnfNodePtr node) {
    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
  };
  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
  auto is_quant_cnode = [](AnfNodePtr node) {
    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
  };
  for (auto node : nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() != 3) {
      continue;
    }
    auto x = cnode->input(1);
    auto weight = cnode->input(2);
    if (!is_quant_cnode(weight)) {
      continue;
    }
    // Get the name of the weight parameter feeding the weight-side fake quant node.
    cnode = weight->cast<CNodePtr>();
    auto weight_node = cnode->input(2);
    if (!weight_node->isa<Parameter>()) {
      continue;
    }
    auto weight_name = weight_node->cast<ParameterPtr>()->name();
    // Walk up the activation input to find its fake quant node, at most max_depth hops away.
    int count = 0;
    int max_depth = 5;
    while (!is_quant_cnode(x)) {
      if (count >= max_depth) {
        break;
      }
      cnode = x->cast<CNodePtr>();
      if (cnode == nullptr || cnode->size() <= 1) {
        break;
      }
      x = cnode->input(1);
      count += 1;
    }
    // Get the name of the fake quant node's `minq` parameter.
    if (!is_quant_cnode(x)) {
      continue;
    }
    cnode = x->cast<CNodePtr>();
    if (cnode == nullptr || cnode->size() != 4) {
      continue;
    }
    auto fakequant_min_node = cnode->input(2);
    if (!fakequant_min_node->isa<Parameter>()) {
      continue;
    }
    auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
    auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
    if (!quant_op_value->isa<PrimitivePy>()) {
      continue;
    }
    auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
    fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
  }
  return fake_quant_table;
}

void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
  // save the graph to ExecutorPy
  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
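
For each quantized Conv2D/MatMul, the traversal records the weight parameter's name together with the fake quant primitive found on the activation input and the name of that primitive's `minq` parameter. A hedged sketch of the resulting table as Python sees it (all names below are illustrative):

```python
# Illustration only: maps a weight parameter name to the fake quant primitive
# on its activation input and the name of that primitive's `minq` parameter.
# Strings stand in for the PrimitivePy objects here.
fake_quant_table = {
    "conv1.conv.weight": ("FakeQuantPerChannel", "conv1.fake_quant_act.minq"),
    "fc1.dense.weight": ("FakeQuantPerLayer", "fc1.fake_quant_act.minq"),
}
```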

@@ -97,6 +97,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
  void ReleaseResource(const py::object &phase);
  static void ClearRes();
  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s);

 private:
  ExecutorPy();
  void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);

@@ -39,6 +39,7 @@ namespace mindspore {
enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };
using IncludeFunc = std::function<IncludeType(const AnfNodePtr &)>;
using FilterFunc = std::function<bool(const AnfNodePtr &)>;
using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;
@@ -58,6 +59,9 @@ std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
                                                        const FilterFunc &filter);
std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
                                 const IncludeFunc &include = AlwaysInclude);

@@ -37,7 +37,8 @@ namespace mindspore {
namespace {
class DeepFirstSearcher : public AnfVisitor {
 public:
-  explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {}
+  explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
+      : include_(include), filter_(filter) {}
  ~DeepFirstSearcher() override = default;

  std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {
@@ -61,8 +62,9 @@ class DeepFirstSearcher : public AnfVisitor {
    if (incl == EXCLUDE) {
      return;
    }
-    res_.push_back(node);
+    // The filter only vetoes collecting a node into the results; traversal continues regardless.
+    if (filter_ == nullptr || !filter_(node)) {
+      res_.push_back(node);
+    }
    if (incl == FOLLOW) {
      AnfVisitor::Visit(node);
    }
@@ -71,6 +73,7 @@ class DeepFirstSearcher : public AnfVisitor {
 private:
  size_t seen_{0};
  IncludeFunc include_;
  FilterFunc filter_;
  std::vector<AnfNodePtr> res_{};
};
@@ -160,10 +163,16 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher {
};
}  // namespace

// `include` decides whether the search expands a node; `filter` decides whether the node is put into the results.
std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
  return DeepScopedGraphSearcher(include).Search(root);
}

std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
                                                        const FilterFunc &filter) {
  return DeepFirstSearcher(include, filter).Search(root);
}

std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
  return DeepUsedGraphSearcher(include).Search(root);
}
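
For readers not tracking the C++ closely, here is a hedged Python restatement of the visit rule above; `node.inputs` and the constants stand in for the C++ machinery and are not real MindSpore APIs:

```python
# Illustration only: DeepFirstSearcher's collection rule, restated in Python.
FOLLOW, NOFOLLOW, EXCLUDE = range(3)

def search(node, include, filter_fn, results, seen):
    if id(node) in seen:
        return
    seen.add(id(node))
    incl = include(node)
    if incl == EXCLUDE:
        return                       # neither collected nor expanded
    if filter_fn is None or not filter_fn(node):
        results.append(node)         # filter vetoes collection only
    if incl == FOLLOW:               # expansion is governed by include alone
        for nxt in getattr(node, 'inputs', []):
            search(nxt, include, filter_fn, results, seen)
```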

@@ -526,6 +526,11 @@ class _Executor:
        phase = 'export' + '.' + str(net.create_time)
        export_graph(file_name, file_format, phase)

    def fetch_info_for_quant_export(self, exec_id):
        """Fetch the quant export info of Conv/Matmul nodes from the pipeline."""
        if not self._executor.has_compiled(exec_id):
            return None
        return self._executor.fetch_info_for_quant_export(exec_id)


_executor = _Executor()
_pynative_exec = _PynativeExecutor()
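
A hedged sketch of how an export pass might call this wrapper. The phase string follows the `'export' + '.' + str(net.create_time)` convention visible above; the helper name is made up:

```python
from mindspore.common.api import _executor

def quant_export_info(net):
    """Illustrative helper (not a library API): return the
    {weight_name: (fake_quant_prim, minq_name)} table for a compiled net,
    or None if the export graph has not been compiled yet."""
    phase = 'export' + '.' + str(net.create_time)
    return _executor.fetch_info_for_quant_export(phase)
```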

@@ -18,8 +18,6 @@ from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
-from mindspore.common.tensor import Tensor
-import mindspore.common.dtype as mstype
import mindspore.context as context
from mindspore._checkparam import check_bool, check_typename
from mindspore._extends import cell_attr_register
@@ -85,13 +83,12 @@ class _BatchNorm(Cell):
        self.reshape = P.Reshape()
        self.is_ascend = context.get_context("device_target") == "Ascend"
        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
+        self.momentum = 1.0 - momentum
        if context.get_context("enable_ge"):
            self.is_ge_backend = True
-            self.momentum = Tensor(1.0 - momentum, mstype.float32)
        else:
            self.is_ge_backend = False
-            self.momentum = 1.0 - momentum
        if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
            self.bn_train = P.BatchNorm(is_training=True,
                                        epsilon=self.eps)

@@ -729,8 +729,8 @@ class DenseQuant(Cell):
        self.has_bias = check_bool(has_bias)
        if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
-                    weight_init.shape()[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+                    weight_init.shape[1] != in_channels:
                raise ValueError("weight_init shape error")
        self.weight = Parameter(initializer(
@@ -738,7 +738,7 @@ class DenseQuant(Cell):
        if self.has_bias:
            if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                    raise ValueError("bias_init shape error")
            self.bias = Parameter(initializer(
@@ -780,8 +780,14 @@ class DenseQuant(Cell):
        return str_info


class _QuantActivation(Cell):
    r"""
    Base class for quantization-aware activation functions. Adds a fake quant op after the activation op.
    """

    def get_origin(self):
        raise NotImplementedError


-class ReLUQuant(Cell):
+class ReLUQuant(_QuantActivation):
    r"""
    ReLUQuant activation function. Add Fake Quant OP after Relu OP.
@@ -828,8 +834,11 @@ class ReLUQuant(Cell):
        x = self.fake_quant_act(x)
        return x

    def get_origin(self):
        return self.relu


-class ReLU6Quant(Cell):
+class ReLU6Quant(_QuantActivation):
    r"""
    ReLU6Quant activation function.
@@ -878,8 +887,10 @@ class ReLU6Quant(Cell):
        x = self.fake_quant_act(x)
        return x

    def get_origin(self):
        return self.relu6


-class HSwishQuant(Cell):
+class HSwishQuant(_QuantActivation):
    r"""
    HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
@@ -935,8 +946,10 @@ class HSwishQuant(Cell):
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act


-class HSigmoidQuant(Cell):
+class HSigmoidQuant(_QuantActivation):
    r"""
    HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP.
@@ -991,6 +1004,8 @@ class HSigmoidQuant(Cell):
        x = self.fake_quant_act_after(x)
        return x

    def get_origin(self):
        return self.act


class TensorAddQuant(Cell):
    r"""
@@ -1083,3 +1098,77 @@ class MulQuant(Cell):
        x = self.mul(x1, x2)
        x = self.fake_quant_act(x)
        return x


class QuantBlock(Cell):
    r"""
    A quant block of Conv/Dense and activation layers for Ascend deploy.

    Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.

    Note:
        This block is only for deploy, and is not trainable.

    Args:
        core_op (Primitive): The computation op to run in Int8, e.g. Conv2D or MatMul.
        weight (Parameter): The quantized weight fed to `core_op`.
        quant_op (Primitive): Op that quantizes the input activation.
        dequant_op (Primitive): Op that dequantizes the accumulated result.
        dequant_scale (Union[Tensor, float]): Scale used by `dequant_op`.
        bias (Parameter): The bias added after `core_op`. Default: None.
        activation (Cell): Activation applied before dequantization. Default: None.

    Inputs:
        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

    Outputs:
        Tensor of shape :math:`(N, out\_channels)`.

    Examples:
        >>> net = QuantBlock(core_op, weight, quant_op, dequant_op, dequant_scale)
        >>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
        >>> net(input)
    """

    def __init__(self,
                 core_op,
                 weight,
                 quant_op,
                 dequant_op,
                 dequant_scale,
                 bias=None,
                 activation=None):
        super(QuantBlock, self).__init__()
        self.core_op = core_op
        self.weight = weight
        self.quant = quant_op
        self.dequant = dequant_op
        self.dequant_scale = dequant_scale
        self.bias = bias
        self.bias_add = P.BiasAdd()  # needed by construct below
        self.has_bias = bias is not None
        self.activation = activation
        self.has_act = activation is not None

    def construct(self, x):
        x = self.quant(x)
        x = self.core_op(x, self.weight)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        if self.has_act:
            x = self.activation(x)
        x = self.dequant(x, self.dequant_scale)
        return x

    def extend_repr(self):
        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
        if self.has_bias:
            str_info = str_info + f', bias={self.bias}'
        if self.has_act:
            str_info = str_info + f', activation={self.activation}'
        str_info = str_info + f', dequant={self.dequant}'
        return str_info
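
QuantBlock's construct encodes the classic int8 inference pipeline: quantize the activation, run the core op on integers (the MatMul/Conv2D dtype change below makes int8 inputs accumulate in int32), then dequantize. A minimal numpy sketch of that dataflow, with made-up scales:

```python
import numpy as np

# Made-up quantization scales and data; w_q plays the role of the exported
# int8 weight, and the int32 accumulation mirrors the dtype rules below.
x = np.random.randn(2, 3).astype(np.float32)
w_q = np.random.randint(-128, 128, size=(4, 3), dtype=np.int8)
x_scale, w_scale = 0.02, 0.05

x_q = np.clip(np.round(x / x_scale), -128, 127).astype(np.int8)  # quant
acc = x_q.astype(np.int32) @ w_q.astype(np.int32).T              # int8 matmul -> int32
y = acc.astype(np.float32) * (x_scale * w_scale)                 # dequant
```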

@@ -584,6 +584,8 @@ class MatMul(PrimitiveWithInfer):
    def infer_dtype(self, x, y):
        args = {"x": x, "y": y}
        validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
        if x.element_type() == mstype.int8:
            return mstype.tensor_type(mstype.int32)
        return x

@@ -800,7 +800,7 @@ class Conv2D(PrimitiveWithInfer):
    def infer_shape(self, x_shape, w_shape):
        validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
-        validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
        validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
@@ -846,6 +846,8 @@ class Conv2D(PrimitiveWithInfer):
        args = {'x': x_dtype, 'w': w_dtype}
        valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
        validator.check_tensor_type_same(args, valid_types, self.name)
        if x_dtype.element_type() == mstype.int8:
            return mstype.tensor_type(mstype.int32)
        return x_dtype

@@ -43,11 +43,12 @@ class Primitive(Primitive_):
        >>> # init a Primitive obj with attr1=1 and attr2=2
        >>> add = Add(attr1=1, attr2=2)
    """
    _repr_ignore_list = ['input_names', 'output_names']

    def __init__(self, name):
        self.name = name
        self.attrs = {}
-        self.init_attrs = {}
+        self.init_attrs = {"name": name}
        Primitive_.__init__(self, name, self)
        if hasattr(self.__class__, '__mindspore_signature__'):
            sig = self._fill_signature(self.__class__.__mindspore_signature__)
@@ -165,6 +166,16 @@ class Primitive(Primitive_):
    def __setstate__(self, d):
        self.__dict__.update(d)

    def __deepcopy__(self, memo):
        return type(self)(**self.init_attrs)

    def __repr__(self):
        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if k not in Primitive._repr_ignore_list])
        info_str = f'Prim[{self.name}]'
        if attr:
            info_str += f'<{attr}>'
        return info_str

    def init_prim_io_names(self, inputs, outputs):
        """
        Initializes the input and output names of the Tensor or attributes.
@@ -185,8 +196,8 @@ class PrimitiveWithInfer(Primitive):
    There are four methods that can be overridden to define the infer logic of a primitive: __infer__(), infer_shape(),
    infer_dtype(), and infer_value(). If __infer__() is defined, it has the highest priority
-    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describle shape
-    and type infer logic. The infer_value() is used for constant propogation.
+    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe shape
+    and type infer logic. The infer_value() is used for constant propagation.

    Args:
        name (str): Name for current Primitive.
@@ -288,6 +299,7 @@ def prim_attr_register(fn):
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self.add_prim_attr(name, value)
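
A small sketch of what the new `__repr__` and `__deepcopy__` give you, assuming `add_prim_attr` populates `attrs` as elsewhere in this file; the op name is made up:

```python
import copy
from mindspore.ops.primitive import Primitive

p = Primitive('MyOp')
p.add_prim_attr('attr1', 1)
print(p)              # Prim[MyOp]<attr1=1>
q = copy.deepcopy(p)  # rebuilt as type(p)(**p.init_attrs), i.e. Primitive(name='MyOp')
```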

[File diff suppressed because it is too large]

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
-"""quantization utils."""
+"""Quantization utils."""

import numpy as np
@@ -24,22 +24,19 @@ def cal_quantization_params(input_min,
                            symmetric=False,
                            narrow_range=False):
    r"""
-    calculate quantization params for scale and zero point.
+    Calculate quantization params for scale and zero point.

    Args:
-        input_min (int, list): The dimension of channel or 1.
-        input_max (int, list): The dimension of channel or 1.
+        input_min (numpy.ndarray): The dimension of channel or 1.
+        input_max (numpy.ndarray): The dimension of channel or 1.
        data_type (numpy type): Can be numpy int8 or numpy uint8.
        num_bits (int): Number of quantization bits, supports 4 and 8 bit. Default: 8.
        symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
        narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.

-    Outputs:
-        scale (int, list): quantization param.
-        zero point (int, list): quantization param.
-
-    Examples:
-        >>> scale, zp = cal_quantization_params([1, 2, 1], [-2, 0, -1], 8, False, False)
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
    """
    input_max = np.maximum(0.0, input_max)
    input_min = np.minimum(0.0, input_min)
@@ -92,27 +89,103 @@ def weight2int(data,
               scale,
               zero_point):
    r"""
-    calculate int8/uint8 weight from fp32. the formula is defined as:
+    Calculate int8/uint8 weight from fp32. The formula is defined as:

    .. math::
        int8/uint8 = round(float/scale) + offset

    Args:
-        data (int, list): The dimension of channel or 1. Should be NCHW.
-        scale (int, list): The dimension of channel or 1.
-        zero_point (int, list): The dimension of channel or 1.
+        data (numpy.ndarray): The dimension of channel or 1. Should be NCHW.
+        scale (numpy.ndarray): The dimension of channel or 1.
+        zero_point (numpy.ndarray): The dimension of channel or 1.

-    Outputs:
-        weight (int, list): The dimension of channel or 1.
-
-    Examples:
-        >>> weight = weight2int([1, 2, 1], 1, 0)
+    Returns:
+        weight (numpy.ndarray): The dimension of channel or 1.
    """
    if scale.shape != zero_point.shape:
        raise ValueError("scale and zero_point should have the same shape.")
    if scale.shape[0] > 0:
-        scale = scale.reshape(1, -1, 1, 1)
-        zero_point = zero_point.reshape(1, -1, 1, 1)
+        scale = scale.reshape(1, -1)
+        zero_point = zero_point.reshape(1, -1)
    return np.round((data / scale) + zero_point)
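
A hedged usage sketch of the two helpers above, assuming `data_type` is the third positional parameter as the docstrings suggest; the ranges and weights are made up:

```python
import numpy as np

# Per-layer (scalar) activation range observed during QAT, made-up values.
input_min = np.array([-1.0], dtype=np.float32)
input_max = np.array([2.0], dtype=np.float32)

scale, zp = cal_quantization_params(input_min, input_max, np.int8, num_bits=8)
weight = np.array([[0.5, -1.0], [0.25, 2.0]], dtype=np.float32)
weight_int = weight2int(weight, scale, zp)  # elementwise round(w / scale) + zp
```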


def scale_zp_from_fack_quant_cell(cell, data_type):
    r"""
    Calculate quantization params (scale and zero point) from a `FakeQuantWithMinMax` cell.

    Args:
        cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`.
        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    minq = cell.minq.data.asnumpy()
    maxq = cell.maxq.data.asnumpy()
    op = cell.fake_quant
    scale, zp = cal_quantization_params(
        minq, maxq, data_type,
        num_bits=op.num_bits,
        symmetric=op.symmetric,
        narrow_range=op.narrow_range)
    return scale, zp


def scale_zp_from_data(op, minq, maxq, data_type):
    r"""
    Calculate quantization params (scale and zero point) from a fake quant primitive
    and the `minq`/`maxq` Parameters of `FakeQuantWithMinMax`.

    Args:
        op (Primitive): Fake quant primitive, `mindspore.ops.operation.FakeQuantPerLayer` or
            `mindspore.ops.operation.FakeQuantPerChannel`.
        minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`.
        maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`.
        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.

    Returns:
        scale (numpy.ndarray): quantization param.
        zero point (numpy.ndarray): quantization param.
    """
    minq = minq.data.asnumpy()
    maxq = maxq.data.asnumpy()
    scale, zp = cal_quantization_params(
        minq, maxq, data_type,
        num_bits=op.num_bits,
        symmetric=op.symmetric,
        narrow_range=op.narrow_range)
    return scale, zp


def fold_batchnorm(weight, cell_quant):
    r"""
    Fold the batchnorm in `Conv2dBatchNormQuant` into the weight.

    Args:
        weight (numpy.ndarray): Weight of `cell_quant`.
        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.

    Returns:
        weight (numpy.ndarray): Folded weight.
        bias (numpy.ndarray): Folded bias.
    """
    variance = cell_quant.moving_variance.data.asnumpy()
    mean = cell_quant.moving_mean.data.asnumpy()
    gamma = cell_quant.gamma.data.asnumpy()
    beta = cell_quant.beta.data.asnumpy()
    epsilon = cell_quant.eps
    sigma = np.sqrt(variance + epsilon)
    gamma = gamma.reshape(-1, 1, 1, 1)
    sigma = sigma.reshape(-1, 1, 1, 1)
    mean = mean.reshape(-1, 1, 1, 1)
    weight = weight * gamma / sigma
    bias = beta - gamma * mean / sigma
    return weight, bias
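
The fold rewrites `gamma * (conv(x, w) - mean) / sigma + beta` as `conv(x, w_fold) + b_fold`. A quick scalar numpy check of the identity behind `fold_batchnorm`, with made-up values:

```python
import numpy as np

x, w = 3.0, 0.5
gamma, beta, mean, variance, eps = 1.2, 0.1, 0.4, 0.25, 1e-5

sigma = np.sqrt(variance + eps)
bn_out = gamma * (x * w - mean) / sigma + beta  # conv followed by batchnorm
w_fold = w * gamma / sigma                      # folded weight
b_fold = beta - gamma * mean / sigma            # folded bias
assert np.isclose(bn_out, x * w_fold + b_fold)
```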

@@ -55,7 +55,7 @@ def init_net_param(network, init_value='ones'):
    params = network.trainable_params()
    for p in params:
        if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
-            p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
+            p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype))


class ModelCallback(Callback):
    def __init__(self):

@@ -13,9 +13,14 @@
# limitations under the License.
# ============================================================================
""" tests for quant """
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore import nn
from mindspore.train.quant import quant as qat
from mobilenetv2_combined import MobileNetV2

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@@ -37,23 +42,45 @@ class LeNet5(nn.Cell):
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6')
-        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
+        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
        self.fc3 = nn.DenseBnAct(84, self.num_class)
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.flattern = nn.Flatten()
+        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
-        x = self.bn(x)
-        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.max_pool2d(x)
-        x = self.flattern(x)
+        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_lenet():
    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
    net = LeNet5()
    net = qat.convert_quant_network(
        net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
    # should load a trained checkpoint here; mocked for the unit test
    for param in net.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")


@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_mobile():
    net = MobileNetV2()
    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
    net = qat.convert_quant_network(
        net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8)
    # should load a trained checkpoint here; mocked for the unit test
    for param in net.get_parameters():
        param.init_data()
    qat.export_geir(net, img, file_name="quant.pb")
