From 1d77bf86a96bbd486d794771aaadfd0bf99af597 Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Thu, 18 Jun 2020 20:39:58 +0800
Subject: [PATCH] Exports MindSpore quant predict model to deploy with GEIR

---
 mindspore/ccsrc/ir/tensor.cc                    |  13 ++
 mindspore/ccsrc/operator/ops.cc                 |   2 +
 mindspore/ccsrc/operator/ops.h                  |   2 +
 mindspore/ccsrc/pipeline/init.cc                |   2 +
 mindspore/ccsrc/pipeline/pipeline.cc            |  69 +++++++
 mindspore/ccsrc/pipeline/pipeline.h             |   2 +
 mindspore/ccsrc/utils/graph_utils.h             |   4 +
 mindspore/ccsrc/utils/graph_utils_extends.cc    |  15 +-
 mindspore/common/api.py                         |   5 +
 mindspore/nn/layer/normalization.py             |   7 +-
 mindspore/nn/layer/quant.py                     | 103 +++++++++-
 mindspore/ops/operations/math_ops.py            |   2 +
 mindspore/ops/operations/nn_ops.py              |   4 +-
 mindspore/ops/primitive.py                      |  18 +-
 mindspore/train/quant/quant.py                  | 189 ++++++++++++++++--
 mindspore/train/quant/quant_utils.py            | 117 +++++++++--
 tests/st/model_zoo_tests/yolov3/test_yolov3.py  |   2 +-
 tests/ut/python/train/quant/test_quant.py       |  41 +++-
 18 files changed, 529 insertions(+), 68 deletions(-)

diff --git a/mindspore/ccsrc/ir/tensor.cc b/mindspore/ccsrc/ir/tensor.cc
index 53a4ef49d5..4e2e996bac 100644
--- a/mindspore/ccsrc/ir/tensor.cc
+++ b/mindspore/ccsrc/ir/tensor.cc
@@ -487,6 +487,19 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
                          }));
                        (void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
                          .def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
+                         .def(py::pickle(
+                           [](const MetaTensor &t) {  // __getstate__
+                             /* Return a tuple that fully encodes the state of the object */
+                             return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
+                           },
+                           [](const py::tuple &t) {  // __setstate__
+                             if (t.size() != 2) {
+                               throw std::runtime_error("Invalid state!");
+                             }
+                             /* Create a new C++ instance */
+                             MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
+                             return tensor;
+                           }))
                          .def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
                          .def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
                          .def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.");
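These pickle hooks are what let the Python side `pickle`/`copy.deepcopy` a MetaTensor through a (dtype, shape) tuple, which the quant exporter below relies on when it deep-copies the network. A minimal, self-contained Python sketch of the same protocol, using a hypothetical stand-in class rather than the real binding:

    import pickle

    class MetaTensorStub:
        """Hypothetical stand-in mirroring the pybind __getstate__/__setstate__ pair."""
        def __init__(self, dtype, shape):
            self.dtype, self.shape = dtype, shape
        def __getstate__(self):
            return (self.dtype, self.shape)   # the tuple fully encodes the state
        def __setstate__(self, state):
            if len(state) != 2:
                raise RuntimeError("Invalid state!")
            self.dtype, self.shape = state

    t = pickle.loads(pickle.dumps(MetaTensorStub(32, [2, 3])))
    assert t.shape == [2, 3]
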
diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index 3d93a33356..cae61f64d0 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -220,6 +220,8 @@ const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
 const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
 const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
 const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
+const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared<Primitive>("FakeQuantPerLayer");
+const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared<Primitive>("FakeQuantPerChannel");
 
 // Other miscellaneous
 const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index c6bea7fe7a..3b9ac01089 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -228,6 +228,8 @@ extern const PrimitivePtr kPrimActivation;
 extern const PrimitivePtr kPrimZerosLike;
 extern const PrimitivePtr kPrimFakeBprop;
 extern const PrimitivePtr kPrimBpropCut;
+extern const PrimitivePtr kPrimFakeQuantPerLayer;
+extern const PrimitivePtr kPrimFakeQuantPerChannel;
 
 // Other Miscellaneous
 extern const PrimitivePtr kPrimIdentity;
diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc
index 998c530cf8..b7c964958c 100644
--- a/mindspore/ccsrc/pipeline/init.cc
+++ b/mindspore/ccsrc/pipeline/init.cc
@@ -77,6 +77,8 @@ PYBIND11_MODULE(_c_expression, m) {
          "Get CNode Strategy Dictionary.")
     .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
          "Get Allreduce Fusion Dictionary.")
+    .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
+         "Fetch the inputs of Conv or Matmul for quant export.")
     .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
          py::arg("broadcast_params") = py::dict(), "Build data graph.")
     .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.")
diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc
index 0f985e2607..f225b7cf98 100644
--- a/mindspore/ccsrc/pipeline/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/pipeline.cc
@@ -281,6 +281,75 @@ ExecutorPy::~ExecutorPy() {
   ConfigManager::GetInstance().ResetConfig();
 }
 
+std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport(
+    const std::string &phase_s) {
+  FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
+  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
+  auto filter = [](AnfNodePtr node) {
+    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
+  };
+  std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
+  auto is_quant_cnode = [](AnfNodePtr node) {
+    return IsPrimitiveCNode(node, prim::kPrimFakeQuantPerLayer) ||
+           IsPrimitiveCNode(node, prim::kPrimFakeQuantPerChannel);
+  };
+  for (auto node : nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr || cnode->size() != 3) {
+      continue;
+    }
+    auto x = cnode->input(1);
+    auto weight = cnode->input(2);
+    if (!is_quant_cnode(weight)) {
+      continue;
+    }
+    // get parameter weight's name
+    cnode = weight->cast<CNodePtr>();
+    auto weight_node = cnode->input(2);
+    if (!weight_node->isa<Parameter>()) {
+      continue;
+    }
+    auto weight_name = weight_node->cast<ParameterPtr>()->name();
+    // find the fakequant from input
+    int count = 0;
+    int max_depth = 5;
+    while (!is_quant_cnode(x)) {
+      if (count >= max_depth) {
+        break;
+      }
+      cnode = x->cast<CNodePtr>();
+      if (cnode == nullptr || cnode->size() <= 1) {
+        break;
+      }
+      x = cnode->input(1);
+      count += 1;
+    }
+    // get the fakequant parameter minq's name
+    if (!is_quant_cnode(x)) {
+      continue;
+    }
+    cnode = x->cast<CNodePtr>();
+    if (cnode == nullptr || cnode->size() != 4) {
+      continue;
+    }
+    auto fakequant_min_node = cnode->input(2);
+    if (!fakequant_min_node->isa<Parameter>()) {
+      continue;
+    }
+    auto fakequant_min_node_name = fakequant_min_node->cast<ParameterPtr>()->name();
+    auto quant_op_value = cnode->input(0)->cast<ValueNodePtr>()->value();
+    if (!quant_op_value->isa<PrimitivePy>()) {
+      continue;
+    }
+    auto quant_op = quant_op_value->cast<PrimitivePyPtr>();
+    fake_quant_table[weight_name] = std::make_pair(quant_op, fakequant_min_node_name);
+  }
+
+  return fake_quant_table;
+}
+
 void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
   // save the graph to ExecutorPy
   FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph();
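For intuition, here is a rough Python analog of the lookup that pass performs: for each Conv2D/MatMul it records the weight parameter's name together with the fake-quant primitive and `minq` parameter name found on the activation input (walking up at most 5 nodes). The `Node` class below is a hypothetical toy IR, not MindSpore API:

    class Node:
        def __init__(self, op, inputs=(), name=""):
            self.op, self.inputs, self.name = op, list(inputs), name

    def is_quant(n):
        return n.op in ("FakeQuantPerLayer", "FakeQuantPerChannel")

    def fetch_info(conv):
        x, weight = conv.inputs               # [activation, weight]
        if not is_quant(weight):
            return None
        weight_name = weight.inputs[1].name   # parameter feeding the weight fake-quant
        for _ in range(5):                    # max_depth == 5 in the C++ pass
            if is_quant(x) or not x.inputs:
                break
            x = x.inputs[0]
        if not is_quant(x):
            return None
        minq_name = x.inputs[1].name          # `minq` parameter of the input fake-quant
        return weight_name, (x.op, minq_name)

    w = Node("Parameter", name="conv1.weight")
    minq = Node("Parameter", name="conv1.input.minq")
    fq_w = Node("FakeQuantPerChannel", [Node("Load"), w])
    fq_x = Node("FakeQuantPerLayer", [Node("data"), minq, Node("maxq")])
    print(fetch_info(Node("Conv2D", [fq_x, fq_w])))
    # ('conv1.weight', ('FakeQuantPerLayer', 'conv1.input.minq'))
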
diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h
index 633ff78c0b..3f1274c417 100644
--- a/mindspore/ccsrc/pipeline/pipeline.h
+++ b/mindspore/ccsrc/pipeline/pipeline.h
@@ -97,6 +97,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
   void ReleaseResource(const py::object &phase);
   static void ClearRes();
 
+  std::map<std::string, std::pair<PrimitivePyPtr, std::string>> FetchInfoForQuantExport(const std::string &phase_s);
+
  private:
   ExecutorPy();
   void ConvertObjectToTensors(const py::dict &dict, std::map<std::string, tensor::TensorPtr> *tensors);
diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h
index 0b49615523..e2703a2877 100644
--- a/mindspore/ccsrc/utils/graph_utils.h
+++ b/mindspore/ccsrc/utils/graph_utils.h
@@ -39,6 +39,7 @@ namespace mindspore {
 enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE };
 
 using IncludeFunc = std::function<IncludeType(AnfNodePtr)>;
+using FilterFunc = std::function<bool(AnfNodePtr)>;
 using SuccFunc = std::function<std::vector<AnfNodePtr>(AnfNodePtr)>;
 using SearchFunc = std::function<std::vector<AnfNodePtr>(const AnfNodePtr &, const IncludeFunc &)>;
 
@@ -58,6 +59,9 @@
 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
 std::vector<AnfNodePtr> DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude);
 
+std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
+                                                        const FilterFunc &filter);
+
 std::vector<AnfNodePtr> TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming,
                                  const IncludeFunc &include = AlwaysInclude);
diff --git a/mindspore/ccsrc/utils/graph_utils_extends.cc b/mindspore/ccsrc/utils/graph_utils_extends.cc
index 7c3991b638..85f9986a0d 100644
--- a/mindspore/ccsrc/utils/graph_utils_extends.cc
+++ b/mindspore/ccsrc/utils/graph_utils_extends.cc
@@ -37,7 +37,8 @@ namespace mindspore {
 namespace {
 class DeepFirstSearcher : public AnfVisitor {
  public:
-  explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {}
+  explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
+      : include_(include), filter_(filter) {}
   ~DeepFirstSearcher() override = default;
 
   std::vector<AnfNodePtr> Search(const AnfNodePtr &root) {
@@ -61,8 +62,9 @@
     if (incl == EXCLUDE) {
       return;
     }
-
-    res_.push_back(node);
+    if (filter_ == nullptr || !filter_(node)) {
+      res_.push_back(node);
+    }
     if (incl == FOLLOW) {
       AnfVisitor::Visit(node);
     }
@@ -71,6 +73,7 @@
  private:
   size_t seen_{0};
   IncludeFunc include_;
+  FilterFunc filter_;
   std::vector<AnfNodePtr> res_{};
 };
 
@@ -160,10 +163,16 @@ class DeepLinkedGraphSearcher : public DeepFirstSearcher {
 };
 }  // namespace
 
+// `include` decides whether the search expands a node; `filter` decides whether the node goes into the results.
 std::vector<AnfNodePtr> DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
   return DeepScopedGraphSearcher(include).Search(root);
 }
 
+std::vector<AnfNodePtr> DeepScopedGraphSearchWithFilter(const AnfNodePtr &root, const IncludeFunc &include,
+                                                        const FilterFunc &filter) {
+  return DeepFirstSearcher(include, filter).Search(root);
+}
+
 std::vector<AnfNodePtr> DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) {
   return DeepUsedGraphSearcher(include).Search(root);
 }
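The include/filter split is the key design choice here: the traversal keeps walking through filtered-out nodes, it just does not report them. A minimal Python sketch of the same search over a toy graph (node layout assumed, not the real AnfNode API):

    FOLLOW, NOFOLLOW, EXCLUDE = range(3)

    def deep_search_with_filter(root, include, filter_fn, successors):
        """Python analog of DeepScopedGraphSearchWithFilter over nodes with a successor function."""
        seen, result = set(), []
        def visit(node):
            if id(node) in seen:
                return
            seen.add(id(node))
            incl = include(node)
            if incl == EXCLUDE:
                return
            if filter_fn is None or not filter_fn(node):
                result.append(node)        # filter decides membership in the results
            if incl == FOLLOW:
                for nxt in successors(node):
                    visit(nxt)             # include decides whether to keep expanding
        visit(root)
        return result

    # e.g. keep only Conv2D/MatMul nodes while still traversing everything:
    # nodes = deep_search_with_filter(ret, lambda n: FOLLOW, is_not_conv_or_matmul, lambda n: n.inputs)
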
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index f2e122a78c..f29b931c53 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -526,6 +526,11 @@ class _Executor:
         phase = 'export' + '.' + str(net.create_time)
         export_graph(file_name, file_format, phase)
 
+    def fetch_info_for_quant_export(self, exec_id):
+        """Fetch the quant export info table from the pipeline."""
+        if self._executor.has_compiled(exec_id) is False:
+            return None
+        return self._executor.fetch_info_for_quant_export(exec_id)
 
 _executor = _Executor()
 _pynative_exec = _PynativeExecutor()
diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py
index 8ffb7664f0..238a3e3431 100644
--- a/mindspore/nn/layer/normalization.py
+++ b/mindspore/nn/layer/normalization.py
@@ -18,8 +18,6 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.ops.primitive import constexpr
-from mindspore.common.tensor import Tensor
-import mindspore.common.dtype as mstype
 import mindspore.context as context
 from mindspore._checkparam import check_bool, check_typename
 from mindspore._extends import cell_attr_register
@@ -85,13 +83,12 @@ class _BatchNorm(Cell):
         self.reshape = P.Reshape()
         self.is_ascend = context.get_context("device_target") == "Ascend"
         self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
-
+        self.momentum = 1.0 - momentum
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
-            self.momentum = Tensor(1.0 - momentum, mstype.float32)
         else:
             self.is_ge_backend = False
-            self.momentum = 1.0 - momentum
+
         if self.is_graph_mode and (self.is_ge_backend or self.is_ascend):
             self.bn_train = P.BatchNorm(is_training=True, epsilon=self.eps)
diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 5a70066006..b67dbc60af 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -729,8 +729,8 @@ class DenseQuant(Cell):
         self.has_bias = check_bool(has_bias)
 
         if isinstance(weight_init, Tensor):
-            if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
-                    weight_init.shape()[1] != in_channels:
+            if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
+                    weight_init.shape[1] != in_channels:
                 raise ValueError("weight_init shape error")
 
         self.weight = Parameter(initializer(
@@ -738,7 +738,7 @@
 
         if self.has_bias:
             if isinstance(bias_init, Tensor):
-                if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
+                if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
                     raise ValueError("bias_init shape error")
 
             self.bias = Parameter(initializer(
@@ -780,8 +780,14 @@
         return str_info
 
 
+class _QuantActivation(Cell):
+    r"""
+    Base class for quant activation functions. Adds a fake quant OP after the activation OP.
+    """
+
+    def get_origin(self):
+        raise NotImplementedError
+
+
-class ReLUQuant(Cell):
+class ReLUQuant(_QuantActivation):
     r"""
     ReLUQuant activation function. Add Fake Quant OP after Relu OP.
 
@@ -828,8 +834,11 @@
         x = self.fake_quant_act(x)
         return x
 
+    def get_origin(self):
+        return self.relu
+
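For clarity, a sketch of the pattern these wrappers follow and how a caller recovers the float op; `ToyReLUQuant` is illustrative only, not part of the patch:

    from mindspore import nn

    class ToyReLUQuant(nn.Cell):
        """Illustrative stand-in for ReLUQuant: float op plus fake quant, with get_origin()."""
        def __init__(self):
            super(ToyReLUQuant, self).__init__()
            self.relu = nn.ReLU()
            # a real ReLUQuant would also build a FakeQuantWithMinMax here

        def construct(self, x):
            return self.relu(x)   # followed by the fake quant OP in the real cell

        def get_origin(self):
            return self.relu      # the exporter calls this to strip the fake quant

    act = ToyReLUQuant()
    plain = act.get_origin() if hasattr(act, "get_origin") else act
    assert isinstance(plain, nn.ReLU)
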
-class ReLU6Quant(Cell):
+class ReLU6Quant(_QuantActivation):
     r"""
     ReLU6Quant activation function.
 
@@ -878,8 +887,10 @@
         x = self.fake_quant_act(x)
         return x
 
+    def get_origin(self):
+        return self.relu6
+
 
-class HSwishQuant(Cell):
+class HSwishQuant(_QuantActivation):
     r"""
     HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
 
@@ -935,8 +946,10 @@
         x = self.fake_quant_act_after(x)
         return x
 
+    def get_origin(self):
+        return self.act
+
 
-class HSigmoidQuant(Cell):
+class HSigmoidQuant(_QuantActivation):
     r"""
     HSigmoidQuant activation function.
     Add Fake Quant OP before and after HSigmoid OP.
 
@@ -991,6 +1004,8 @@
         x = self.fake_quant_act_after(x)
         return x
 
+    def get_origin(self):
+        return self.act
+
 
 class TensorAddQuant(Cell):
     r"""
@@ -1083,3 +1098,77 @@
         x = self.mul(x1, x2)
         x = self.fake_quant_act(x)
         return x
+
+
+class QuantBlock(Cell):
+    r"""
+    A quant block of Conv/Dense, activation layer for Ascend deploy.
+
+    Calculate Conv or Dense in Int8, with AscendQuant and AscendDeQuant.
+
+    Note:
+        This block is only for deploy, and not trainable.
+
+    Args:
+        core_op (Primitive): The int8 Conv2D or MatMul primitive that does the core computation.
+        weight (Tensor): The quantized int8 weight.
+        quant_op (Primitive): `AscendQuant` OP that quantizes the float input to int8.
+        dequant_op (Primitive): `AscendDequant` OP that dequantizes the int32 output to float.
+        dequant_scale (Tensor): The scale used by `dequant_op`.
+        bias (Tensor): The quantized int32 bias. Default: None.
+        activation (Cell): The activation applied after dequantization. Default: None.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
+
+    Outputs:
+        Tensor of shape :math:`(N, out\_channels)`.
+    """
+
+    def __init__(self,
+                 core_op,
+                 weight,
+                 quant_op,
+                 dequant_op,
+                 dequant_scale,
+                 bias=None,
+                 activation=None):
+        super(QuantBlock, self).__init__()
+        self.core_op = core_op
+        self.weight = weight
+        self.quant = quant_op
+        self.dequant = dequant_op
+        self.dequant_scale = dequant_scale
+        self.bias = bias
+        self.has_bias = bias is not None
+        self.bias_add = P.BiasAdd()
+        self.activation = activation
+        self.has_act = activation is not None
+
+    def construct(self, x):
+        x = self.quant(x)
+        x = self.core_op(x, self.weight)
+        if self.has_bias:
+            x = self.bias_add(x, self.bias)
+        x = self.dequant(x, self.dequant_scale)
+        if self.has_act:
+            x = self.activation(x)
+        return x
+
+    def extend_repr(self):
+        str_info = f'quant={self.quant}, core_op={type(self.core_op)}'
+        if self.has_bias:
+            str_info = str_info + f', bias={self.bias}'
+        if self.has_act:
+            str_info = str_info + f', activation={self.activation}'
+        str_info = str_info + f', dequant={self.dequant}'
+        return str_info
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index c7cf9f7807..038a14eede 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -584,6 +584,8 @@ class MatMul(PrimitiveWithInfer):
     def infer_dtype(self, x, y):
         args = {"x": x, "y": y}
         validator.check_tensor_type_same(args, mstype.float_type + mstype.int_type, self.name)
+        if x.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x
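Worked numbers for the int8 path this enables (pure numpy, scales assumed for illustration; it mirrors AscendQuant -> int8 MatMul with int32 accumulation -> AscendDequant):

    import numpy as np

    scale_in, scale_w = 0.05, 0.02            # per-layer scales, zero points assumed 0
    x = np.array([[0.10, -0.25, 0.40]], dtype=np.float32)
    w = np.array([[0.30, -0.10, 0.20],
                  [-0.40, 0.50, 0.00]], dtype=np.float32)

    x_q = np.round(x / scale_in)              # AscendQuant analog: float -> int8 grid
    w_q = np.round(w / scale_w)               # weight2int analog with zero_point = 0
    acc = x_q @ w_q.T                         # int8 x int8 -> int32 accumulation
    y = acc * (scale_in * scale_w)            # AscendDequant analog: scale_deq = scale_in * scale_w
    print(y)                                  # [[0.135 -0.165]], matching x @ w.T here

This is also why the MatMul (and Conv2D) dtype inference above returns int32 for int8 inputs: the accumulator is wider than the operands.
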
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 78c777f0f9..fb004658a5 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -800,7 +800,7 @@ class Conv2D(PrimitiveWithInfer):
     def infer_shape(self, x_shape, w_shape):
         validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name)
         validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
-        validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        validator.check(f"x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name)
         validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name)
         validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name)
 
@@ -846,6 +846,8 @@
         args = {'x': x_dtype, 'w': w_dtype}
         valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32]
         validator.check_tensor_type_same(args, valid_types, self.name)
+        if x_dtype.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x_dtype
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 61da7587a1..7ceb687778 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -43,11 +43,12 @@ class Primitive(Primitive_):
         >>> # init a Primitive obj with attr1=1 and attr2=2
         >>> add = Add(attr1=1, attr2=2)
     """
+    _repr_ignore_list = ['input_names', 'output_names']
 
     def __init__(self, name):
         self.name = name
         self.attrs = {}
-        self.init_attrs = {}
+        self.init_attrs = {"name": name}
         Primitive_.__init__(self, name, self)
         if hasattr(self.__class__, '__mindspore_signature__'):
             sig = self._fill_signature(self.__class__.__mindspore_signature__)
@@ -165,6 +166,16 @@
     def __setstate__(self, d):
         self.__dict__.update(d)
 
+    def __deepcopy__(self, memo):
+        return type(self)(**self.init_attrs)
+
+    def __repr__(self):
+        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if k not in Primitive._repr_ignore_list])
+        info_str = f'Prim[{self.name}]'
+        if attr:
+            info_str += f'<{attr}>'
+        return info_str
+
     def init_prim_io_names(self, inputs, outputs):
         """
         Initializes the names of inputs and outputs of Tensor or attributes.
@@ -185,8 +196,8 @@ class PrimitiveWithInfer(Primitive):
 
     There are four methods that can be overridden to define the infer logic of the primitive:
     __infer__(), infer_shape(), infer_dtype(), and infer_value().
     If __infer__() is defined in primitive, the __infer__() has highest priority
-    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describle shape
-    and type infer logic. The infer_value() is used for constant propogation.
+    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe shape
+    and type infer logic. The infer_value() is used for constant propagation.
 
     Args:
         name (str): Name for current Primitive.
@@ -288,6 +299,7 @@ def prim_attr_register(fn):
         bound_args.apply_defaults()
         arguments = bound_args.arguments
         del arguments['self']
+        del self.init_attrs['name']
         for name in arguments:
             value = arguments[name]
             self.add_prim_attr(name, value)
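A quick sketch of what the new `__repr__` and `__deepcopy__` hooks provide; the printed form follows the code above, though the exact attribute set depends on the op:

    import copy
    from mindspore.ops import operations as P

    matmul = P.MatMul(transpose_a=True)
    print(matmul)                        # e.g. Prim[MatMul]<transpose_a=True, transpose_b=False>
    matmul_copy = copy.deepcopy(matmul)  # rebuilt via type(self)(**self.init_attrs)
    assert matmul_copy.name == matmul.name

The deepcopy support matters here because the exporter below copies the whole trained network before mutating it into a deploy network.
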
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index 5caa5a4440..9de7892eef 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -14,12 +14,23 @@
 # ============================================================================
 """aware quantization."""
 
+import copy
 import re
-from ... import nn
-from ... import ops
+
+import numpy as np
+
+from ... import log as logger
+from ... import nn, ops
 from ..._checkparam import ParamValidator as validator
 from ..._checkparam import Rel
+from ...common import Tensor
+from ...common import dtype as mstype
+from ...common.api import _executor
 from ...nn.layer import quant
+from ...ops import functional as F
+from ...ops.operations import _inner_ops as inner
+from ...train import serialization
+from . import quant_utils
 
 _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.ReLU6: quant.ReLU6Quant,
@@ -27,25 +38,21 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
                    nn.HSwish: quant.HSwishQuant}
 
 
-class _AddFakeQuantInputOutput(nn.Cell):
+class _AddFakeQuantInput(nn.Cell):
     """
-    Add FakeQuant at input and output of the Network. Only support one input and one output case.
+    Add FakeQuant OP at input of the network. Only support one input case.
     """
 
     def __init__(self, network, quant_delay=0):
-        super(_AddFakeQuantInputOutput, self).__init__(auto_prefix=False)
+        super(_AddFakeQuantInput, self).__init__(auto_prefix=False)
         self.network = network
         self.fake_quant_input = quant.FakeQuantWithMinMax(
             min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
         self.fake_quant_input.update_parameters_name('fake_quant_input')
-        self.fake_quant_output = quant.FakeQuantWithMinMax(
-            min_init=-6, max_init=6, quant_delay=quant_delay, ema=True)
-        self.fake_quant_output.update_parameters_name('fake_quant_output')
 
     def construct(self, data):
         data = self.fake_quant_input(data)
         output = self.network(data)
-        output = self.fake_quant_output(output)
         return output
 
@@ -99,6 +106,8 @@ class ConvertToQuantNetwork:
         self.per_channel = validator.check_bool("per channel", per_channel)
         self.symmetric = validator.check_bool("symmetric", symmetric)
         self.narrow_range = validator.check_bool("narrow range", narrow_range)
+        self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
+                                    quant.DenseBnAct: self._convert_dense}
 
     def _convert_op_name(self, name):
         pattern = re.compile(r'([A-Z]{1})')
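For reference, a self-contained sketch of what that CamelCase-to-snake_case helper does; the body past the regex line is reconstructed here as an assumption, since the hunk only shows the pattern:

    import re

    def convert_op_name(name):
        # Insert '_' before each capital letter, then lower-case: "TensorAdd" -> "tensor_add".
        pattern = re.compile(r'([A-Z]{1})')
        new_name = re.sub(pattern, r'_\1', name).lower()
        return new_name[1:] if new_name.startswith('_') else new_name

    assert convert_op_name("TensorAdd") == "tensor_add"
    assert convert_op_name("Mul") == "mul"
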
@@ -110,6 +119,7 @@
     def run(self):
         self.network.update_cell_prefix()
         network = self._convert_subcells2quant(self.network)
+        network = _AddFakeQuantInput(network)
         return network
 
     def _convert_subcells2quant(self, network):
@@ -122,15 +132,9 @@
             subcell = cells[name]
             if subcell == network:
                 continue
-            elif isinstance(subcell, quant.Conv2dBnAct):
+            elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
                 prefix = subcell.param_prefix
-                new_subcell = self._convert_conv(subcell)
-                new_subcell.update_parameters_name(prefix + '.')
-                network.insert_child_to_cell(name, new_subcell)
-                change = True
-            elif isinstance(subcell, quant.DenseBnAct):
-                prefix = subcell.param_prefix
-                new_subcell = self._convert_dense(subcell)
+                new_subcell = self._convert_method_map[type(subcell)](subcell)
                 new_subcell.update_parameters_name(prefix + '.')
                 network.insert_child_to_cell(name, new_subcell)
                 change = True
@@ -199,10 +203,12 @@
                                                symmetric=self.symmetric,
                                                narrow_range=self.narrow_range)
         subcell.conv = conv_inner
-        if subcell.activation is not None:
+        if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
         else:
-            subcell = _AddFakeQuantAfterSubCell(subcell)
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
+                                                           quant_delay=self.quant_delay)
         return subcell
 
     def _convert_dense(self, subcell):
@@ -217,8 +223,12 @@
                                        per_channel=self.per_channel,
                                        num_bits=self.weight_bits)
         subcell.dense = dense_inner
-        if subcell.activation is not None:
+        if subcell.has_act and subcell.activation is not None:
             subcell.activation = self._convert_activation(subcell.activation)
+        else:
+            subcell.has_act = True
+            subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
+                                                           quant_delay=self.quant_delay)
         return subcell
 
     def _convert_activation(self, activation):
@@ -229,6 +239,147 @@
         return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
                                           quant_delay=self.quant_delay)
 
 
+class ExportQuantNetworkDeploy:
+    """
+    Convert quantization aware network to deploy network.
+
+    Args:
+        network (Cell): MindSpore network produced by `convert_quant_network`.
+        inputs (Tensor): Inputs of the `network`.
+
+    Returns:
+        Cell, converted network.
+    """
+    __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
+
+    def __init__(self,
+                 network,
+                 *inputs):
+        network = validator.check_isinstance('network', network, (nn.Cell,))
+        self.data_type = mstype.int8
+        self.network = copy.deepcopy(network)
+        self.all_paramters = {p.name: p for p in self.network.get_parameters()}
+        self.get_inputs_table(inputs)
+
+    def get_inputs_table(self, inputs):
+        """Get the support info for quant export."""
+        phase_name = 'export_quant'
+        graph_id, _ = _executor.compile(self.network, *inputs, phase=phase_name, do_convert=False)
+        self.quant_info_table = _executor.fetch_info_for_quant_export(graph_id)
+
+    def run(self):
+        """Start to convert."""
+        self.network.update_cell_prefix()
+        network = self.network
+        if isinstance(network, _AddFakeQuantInput):
+            network = network.network
+        network = self._convert_quant2deploy(network)
+        return network
+
+    def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
+        """Convert a quant subcell of the network to a deploy subcell."""
+        # Calculate the scale and zero point
+        w_minq_name = cell_core.fake_quant_weight.minq.name
+        np_type = mstype.dtype_to_nptype(self.data_type)
+        scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
+        scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
+        info = self.quant_info_table.get(w_minq_name, None)
+        if info:
+            fack_quant_a_in_op, minq_name = info
+            maxq = self.all_paramters[minq_name[:-4] + "maxq"]
+            minq = self.all_paramters[minq_name]
+            scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
+        else:
+            logger.warning(f"Cannot find the input `fake_quant` whose `minq` matches {w_minq_name}")
+            return None
+
+        # Build the `Quant` and `Dequant` OPs.
+        # AscendQuant only supports the per-layer version; a check is still needed here.
+        quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
+        sqrt_mode = False
+        scale_deq = scale_a_out * scale_w
+        if scale_deq < 2 ** -14:
+            scale_deq = np.sqrt(scale_deq)
+            sqrt_mode = True
+        dequant_op = inner.AscendDequant(sqrt_mode)
+
+        # get op
+        op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
+        if isinstance(activation, _AddFakeQuantAfterSubCell):
+            activation = activation.subcell
+        elif hasattr(activation, "get_origin"):
+            activation = activation.get_origin()
+
+        # get the `weight` and `bias`
+        weight = cell_core.weight.data.asnumpy()
+        bias = None
+        if isinstance(cell_core, (quant.DenseQuant, quant.Conv2dQuant)):
+            if cell_core.has_bias:
+                bias = cell_core.bias.data.asnumpy()
+        elif isinstance(cell_core, quant.Conv2dBatchNormQuant):
+            weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
+
+        # apply the quant
+        weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
+        if bias is not None:
+            bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
+        scale_deq = Tensor(scale_deq, mstype.float16)
+        block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
+        return block
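A numeric note on the `sqrt_mode` branch above: 2 ** -14 is the smallest positive normal float16, so a dequant scale below it would land in the subnormal range once cast to float16 and lose precision. Splitting the scale into two equal factors (applied twice by the dequant OP when sqrt_mode is set) keeps each factor comfortably representable. A quick numpy check:

    import numpy as np

    scale_deq = 3.0e-6                   # below the 2 ** -14 threshold used above
    s = np.sqrt(scale_deq)               # ~1.73e-03, well inside float16's normal range
    s16 = np.float16(s).astype(np.float64)
    assert abs(s16 * s16 - scale_deq) / scale_deq < 5e-3   # applying it twice recovers the scale
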
+    def _convert_quant2deploy(self, network):
+        """Convert all quant subcells of the network to deploy subcells."""
+        cells = network.name_cells()
+        change = False
+        for name in cells:
+            subcell = cells[name]
+            if subcell == network:
+                continue
+            cell_core = None
+            fake_quant_act = None
+            activation = None
+            if isinstance(subcell, quant.Conv2dBnAct):
+                cell_core = subcell.conv
+                activation = subcell.activation
+                fake_quant_act = activation.fake_quant_act
+            elif isinstance(subcell, quant.DenseBnAct):
+                cell_core = subcell.dense
+                activation = subcell.activation
+                fake_quant_act = activation.fake_quant_act
+            if cell_core is not None:
+                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
+                if new_subcell:
+                    prefix = subcell.param_prefix
+                    new_subcell.update_parameters_name(prefix + '.')
+                    network.insert_child_to_cell(name, new_subcell)
+                    change = True
+            elif isinstance(subcell, _AddFakeQuantAfterSubCell):
+                op = subcell.subcell
+                if op.name in self.__quant_op_name__ and isinstance(op, ops.Primitive):
+                    network.__delattr__(name)
+                    network.__setattr__(name, op)
+                    change = True
+            else:
+                self._convert_quant2deploy(subcell)
+        if isinstance(network, nn.SequentialCell) and change:
+            network.cell_list = list(network.cells())
+        return network
+
+
+def export_geir(network, *inputs, file_name):
+    """
+    Exports a MindSpore quant predict model for deployment with GEIR.
+
+    Args:
+        network (Cell): MindSpore network produced by `convert_quant_network`.
+        inputs (Tensor): Inputs of the `network`.
+        file_name (str): File name of the model to export.
+    """
+    exporter = ExportQuantNetworkDeploy(network, *inputs)
+    deploy_net = exporter.run()
+    serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR")
+
+
 def convert_quant_network(network,
                           quant_delay=0,
                           bn_fold=False,
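Taken together, the intended flow looks like the following hedged sketch, using only the APIs added in this patch; the LeNet5 definition is the one from the tests at the end of this diff:

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.quant import quant as qat

    net = LeNet5()                                   # float network built from Conv2dBnAct/DenseBnAct
    net = qat.convert_quant_network(net, quant_delay=0, bn_fold=False,
                                    freeze_bn=10000, weight_bits=8, act_bits=8)
    # ... run quantization aware training, then load the trained checkpoint ...
    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
    qat.export_geir(net, img, file_name="lenet_quant")
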
diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py
index 50927b0cad..c9e6ac92e1 100644
--- a/mindspore/train/quant/quant_utils.py
+++ b/mindspore/train/quant/quant_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""quantization utils."""
+"""Quantization utils."""
 
 import numpy as np
 
@@ -24,22 +24,19 @@ def cal_quantization_params(input_min,
                             input_max,
                             data_type,
                             num_bits=8,
                             symmetric=False,
                             narrow_range=False):
     r"""
-    calculate quantization params for scale and zero point.
+    Calculate quantization params for scale and zero point.
 
     Args:
-        input_min (int, list): The dimension of channel or 1.
-        input_max (int, list): The dimension of channel or 1.
+        input_min (numpy.ndarray): Minimum values, one per channel or a single value.
+        input_max (numpy.ndarray): Maximum values, one per channel or a single value.
         data_type (numpy type): Can be numpy int8 or numpy uint8.
         num_bits (int): Quantization bit width, supports 4 and 8. Default: 8.
         symmetric (bool): Whether the quantization algorithm is symmetric. Default: False.
         narrow_range (bool): Whether the quantization algorithm uses narrow range. Default: False.
 
-    Outputs:
-        scale (int, list): quantization param.
-        zero point (int, list): quantization param.
-
-    Examples:
-        >>> scale, zp = cal_quantization_params([1, 2, 1], [-2, 0, -1], 8, False, False)
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
     """
     input_max = np.maximum(0.0, input_max)
     input_min = np.minimum(0.0, input_min)
@@ -92,27 +89,103 @@
 def weight2int(data, scale, zero_point):
     r"""
-    calculate int8/uint8 weight from fp32. the formula is defined as:
+    Calculate int8/uint8 weight from fp32. The formula is defined as:
 
     .. math::
         int8/uint8 = round(float/scale) + offset
 
     Args:
-        data (int, list): The dimension of channel or 1. Should be NCHW.
-        scale (int, list): The dimension of channel or 1.
-        zero_point (int, list): The dimension of channel or 1.
+        data (numpy.ndarray): The weight to be quantized, in NCHW layout.
+        scale (numpy.ndarray): The scale, one per channel or a single value.
+        zero_point (numpy.ndarray): The zero point, one per channel or a single value.
 
-    Outputs:
-        weight (int, list): The dimension of channel or 1.
-
-    Examples:
-        >>> weight = weight2int([1, 2, 1], 1, 0)
+    Returns:
+        weight (numpy.ndarray): The quantized weight.
     """
     if scale.shape != zero_point.shape:
         raise ValueError("scale and zero_point should have the same shape.")
 
     if scale.shape[0] > 0:
-        scale = scale.reshape(1, -1, 1, 1)
-        zero_point = zero_point.reshape(1, -1, 1, 1)
+        scale = scale.reshape(1, -1)
+        zero_point = zero_point.reshape(1, -1)
 
     return np.round((data/scale) + zero_point)
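A worked per-channel example of that formula (pure numpy, data flattened to (N, C) as the reshape above expects; numbers chosen for illustration):

    import numpy as np

    data = np.array([[0.30, -0.10],
                     [0.24,  0.50]], dtype=np.float32)
    scale = np.array([0.02, 0.05], dtype=np.float32)        # one scale per channel
    zero_point = np.array([0.0, 0.0], dtype=np.float32)

    q = np.round(data / scale.reshape(1, -1)) + zero_point.reshape(1, -1)
    print(q)   # [[15. -2.] [12. 10.]], the integer grid weight2int returns
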
+
+
+def scale_zp_from_fack_quant_cell(cell, data_type):
+    r"""
+    Get the quantization scale and zero point from a `FakeQuantWithMinMax` cell.
+
+    Args:
+        cell (Cell): A `mindspore.nn.layer.FakeQuantWithMinMax` cell.
+        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
+
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
+    """
+    minq = cell.minq.data.asnumpy()
+    maxq = cell.maxq.data.asnumpy()
+    op = cell.fake_quant
+
+    scale, zp = cal_quantization_params(
+        minq, maxq, data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp
+
+
+def scale_zp_from_data(op, minq, maxq, data_type):
+    r"""
+    Get the quantization scale and zero point from a fake quant primitive and its `minq`/`maxq` Parameters.
+
+    Args:
+        op (Primitive): Fake quant primitive `mindspore.ops.operation.FakeQuantPerLayer` or
+            `mindspore.ops.operation.FakeQuantPerChannel`.
+        minq (Parameter): Parameter `minq` of `mindspore.nn.layer.FakeQuantWithMinMax`.
+        maxq (Parameter): Parameter `maxq` of `mindspore.nn.layer.FakeQuantWithMinMax`.
+        data_type (numpy type): Can be `numpy.int8` or `numpy.uint8`.
+
+    Returns:
+        scale (numpy.ndarray): quantization param.
+        zero point (numpy.ndarray): quantization param.
+    """
+    minq = minq.data.asnumpy()
+    maxq = maxq.data.asnumpy()
+
+    scale, zp = cal_quantization_params(
+        minq, maxq, data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp
+
+
+def fold_batchnorm(weight, cell_quant):
+    r"""
+    Fold the batchnorm in `Conv2dBatchNormQuant` into the weight and bias.
+
+    Args:
+        weight (numpy.ndarray): Weight of `cell_quant`.
+        cell_quant (Cell): Object of `mindspore.nn.layer.Conv2dBatchNormQuant`.
+
+    Returns:
+        weight (numpy.ndarray): Folded weight.
+        bias (numpy.ndarray): Folded bias.
+    """
+    variance = cell_quant.moving_variance.data.asnumpy()
+    mean = cell_quant.moving_mean.data.asnumpy()
+    gamma = cell_quant.gamma.data.asnumpy()
+    beta = cell_quant.beta.data.asnumpy()
+    epsilon = cell_quant.eps
+    sigma = np.sqrt(variance + epsilon)
+    bias = beta - gamma * mean / sigma
+    gamma = gamma.reshape(-1, 1, 1, 1)
+    sigma = sigma.reshape(-1, 1, 1, 1)
+    weight = weight * gamma / sigma
+    return weight, bias
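As a sanity check on the folding math (pure numpy, OIHW-style shapes; the identity used is that conv followed by BN equals conv with weight w * gamma / sigma and bias beta - gamma * mean / sigma):

    import numpy as np

    c_out = 4
    w = np.random.randn(c_out, 3, 3, 3).astype(np.float32)
    gamma, beta = np.random.rand(c_out) + 0.5, np.random.randn(c_out)
    mean, var, eps = np.random.randn(c_out), np.random.rand(c_out) + 0.1, 1e-5

    sigma = np.sqrt(var + eps)
    w_fold = w * (gamma / sigma).reshape(-1, 1, 1, 1)
    b_fold = beta - gamma * mean / sigma

    # For any per-channel conv output y: BN(y) == y * gamma / sigma + b_fold
    y = np.random.randn(c_out)
    assert np.allclose(gamma * (y - mean) / sigma + beta,
                       y * gamma / sigma + b_fold, atol=1e-5)
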
diff --git a/tests/st/model_zoo_tests/yolov3/test_yolov3.py b/tests/st/model_zoo_tests/yolov3/test_yolov3.py
index 6b4057db18..126c66a6f3 100644
--- a/tests/st/model_zoo_tests/yolov3/test_yolov3.py
+++ b/tests/st/model_zoo_tests/yolov3/test_yolov3.py
@@ -55,7 +55,7 @@ def init_net_param(network, init_value='ones'):
     params = network.trainable_params()
     for p in params:
         if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
-            p.set_parameter_data(initializer(init_value, p.data.shape(), p.data.dtype()))
+            p.set_parameter_data(initializer(init_value, p.data.shape, p.data.dtype))
 
 class ModelCallback(Callback):
     def __init__(self):
diff --git a/tests/ut/python/train/quant/test_quant.py b/tests/ut/python/train/quant/test_quant.py
index e299c7b9fc..6098354cb0 100644
--- a/tests/ut/python/train/quant/test_quant.py
+++ b/tests/ut/python/train/quant/test_quant.py
@@ -13,9 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """ tests for quant """
+import numpy as np
+import pytest
+
 import mindspore.context as context
+from mindspore import Tensor
 from mindspore import nn
-
+from mindspore.train.quant import quant as qat
+from mobilenetv2_combined import MobileNetV2
 
 context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
 
@@ -37,23 +42,45 @@ class LeNet5(nn.Cell):
     def __init__(self, num_class=10):
         super(LeNet5, self).__init__()
         self.num_class = num_class
-        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6')
-        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
+        self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
+        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
         self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
         self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
         self.fc3 = nn.DenseBnAct(84, self.num_class)
         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.flattern = nn.Flatten()
+        self.flatten = nn.Flatten()
 
     def construct(self, x):
         x = self.conv1(x)
-        x = self.bn(x)
-        x = self.relu(x)
         x = self.max_pool2d(x)
         x = self.conv2(x)
         x = self.max_pool2d(x)
-        x = self.flattern(x)
+        x = self.flatten(x)
         x = self.fc1(x)
         x = self.fc2(x)
         x = self.fc3(x)
         return x
+
+
+@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
+def test_qat_lenet():
+    img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
+    net = LeNet5()
+    net = qat.convert_quant_network(
+        net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
+    # should load the checkpoint. mock here
+    for param in net.get_parameters():
+        param.init_data()
+    qat.export_geir(net, img, file_name="quant.pb")
+
+
+@pytest.mark.skip(reason="no `te.lang.cce` in ut env")
+def test_qat_mobile():
+    net = MobileNetV2()
+    img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
+    net = qat.convert_quant_network(
+        net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8)
+    # should load the checkpoint. mock here
+    for param in net.get_parameters():
+        param.init_data()
+    qat.export_geir(net, img, file_name="quant.pb")