!2836 bug fix in auto create quant graph in master

Merge pull request !2836 from chenzhongming/master
pull/2836/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit d48a5a4863

@ -1193,9 +1193,9 @@ class QuantBlock(Cell):
self.dequant = dequant_op self.dequant = dequant_op
self.dequant_scale = dequant_scale self.dequant_scale = dequant_scale
self.bias = bias self.bias = bias
self.has_bias = bias is None self.has_bias = bias is not None
self.activation = activation self.activation = activation
self.has_act = activation is None self.has_act = activation is not None
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
def construct(self, x): def construct(self, x):

@ -86,7 +86,7 @@ class LossMonitor(Callback):
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
"loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format( "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}ms]".format(
cb_params.cur_epoch_num, cb_params.epoch_num, cb_params.cur_epoch_num, cb_params.epoch_num,
cur_step_in_epoch, int(cb_params.batch_num), cur_step_in_epoch, int(cb_params.batch_num),
step_loss, np.mean(self.losses), step_loss, np.mean(self.losses),

@ -33,7 +33,6 @@ from ...ops.operations import _inner_ops as inner
from ...train import serialization from ...train import serialization
from . import quant_utils from . import quant_utils
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant, _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.ReLU6: quant.ReLU6Quant, nn.ReLU6: quant.ReLU6Quant,
nn.HSigmoid: quant.HSigmoidQuant, nn.HSigmoid: quant.HSigmoidQuant,
@ -178,7 +177,6 @@ class ConvertToQuantNetwork:
dilation=conv_inner.dilation, dilation=conv_inner.dilation,
group=conv_inner.group, group=conv_inner.group,
eps=bn_inner.eps, eps=bn_inner.eps,
momentum=1 - bn_inner.momentum,
quant_delay=self.weight_qdelay, quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn, freeze_bn=self.freeze_bn,
per_channel=self.weight_channel, per_channel=self.weight_channel,
@ -268,16 +266,16 @@ class ConvertToQuantNetwork:
narrow_range=self.act_range) narrow_range=self.act_range)
class ExportQuantNetworkDeploy: class ExportToQuantInferNetwork:
""" """
Convert quantization aware network to deploy network. Convert quantization aware network to infer network.
Args: Args:
network (Cell): MindSpore network produced by `convert_quant_network`. network (Cell): MindSpore network API `convert_quant_network`.
inputs (Tensor): Inputs of the `network`. inputs (Tensor): Input tensors of the `quantization aware training network`.
Returns: Returns:
Cell, converted network. Cell, GEIR backend Infer network.
""" """
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
@ -287,7 +285,7 @@ class ExportQuantNetworkDeploy:
network = validator.check_isinstance('network', network, (nn.Cell,)) network = validator.check_isinstance('network', network, (nn.Cell,))
self.data_type = mstype.int8 self.data_type = mstype.int8
self.network = copy.deepcopy(network) self.network = copy.deepcopy(network)
self.all_paramters = {p.name: p for p in self.network.get_parameters()} self.all_parameters = {p.name: p for p in self.network.get_parameters()}
self.get_inputs_table(inputs) self.get_inputs_table(inputs)
def get_inputs_table(self, inputs): def get_inputs_table(self, inputs):
@ -315,8 +313,8 @@ class ExportQuantNetworkDeploy:
info = self.quant_info_table.get(w_minq_name, None) info = self.quant_info_table.get(w_minq_name, None)
if info: if info:
fack_quant_a_in_op, minq_name = info fack_quant_a_in_op, minq_name = info
maxq = self.all_paramters[minq_name[:-4] + "maxq"] maxq = self.all_parameters[minq_name[:-4] + "maxq"]
minq = self.all_paramters[minq_name] minq = self.all_parameters[minq_name]
scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type) scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
else: else:
logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}") logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
@ -357,7 +355,7 @@ class ExportQuantNetworkDeploy:
return block return block
def _convert_quant2deploy(self, network): def _convert_quant2deploy(self, network):
"""Convet network's all quant subcell to deploy subcell.""" """Convert network's all quant subcell to deploy subcell."""
cells = network.name_cells() cells = network.name_cells()
change = False change = False
for name in cells: for name in cells:
@ -395,18 +393,26 @@ class ExportQuantNetworkDeploy:
return network return network
def export_geir(network, *inputs, file_name): def export(network, *inputs, file_name, file_format='GEIR'):
""" """
Exports MindSpore quant predict model to deploy with GEIR. Exports MindSpore quantization predict model to deploy with GEIR.
Args: Args:
network (Cell): MindSpore network produced by `convert_quant_network`. network (Cell): MindSpore network produced by `convert_quant_network`.
inputs (Tensor): Inputs of the `network`. inputs (Tensor): Inputs of the `quantization aware training network`.
file_name (str): File name of model to export. file_name (str): File name of model to export.
file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.
- GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
""" """
exporter = ExportQuantNetworkDeploy(network, *inputs) supported_formats = ['GEIR']
deploy_net = exporter.run()
serialization.export(deploy_net, *inputs, file_name=file_name, file_format="GEIR") if file_format not in supported_formats:
raise ValueError('Illegal file format {}.'.format(file_format))
if file_format == 'GEIR':
exporter = ExportToQuantInferNetwork(network, *inputs)
deploy_net = exporter.run()
serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
def convert_quant_network(network, def convert_quant_network(network,
@ -443,6 +449,7 @@ def convert_quant_network(network,
Cell, Network which has change to quantization aware training network cell. Cell, Network which has change to quantization aware training network cell.
""" """
support_device = ["Ascend", "GPU"] support_device = ["Ascend", "GPU"]
def convert2list(name, value): def convert2list(name, value):
if not isinstance(value, list) and not isinstance(value, tuple): if not isinstance(value, list) and not isinstance(value, tuple):
value = [value] value = [value]
@ -457,7 +464,7 @@ def convert_quant_network(network,
narrow_range = convert2list("narrow range", narrow_range) narrow_range = convert2list("narrow range", narrow_range)
if context.get_context('device_target') not in support_device: if context.get_context('device_target') not in support_device:
raise KeyError("Not support {} backend.".format(context.get_context('device_target'))) raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
net = ConvertToQuantNetwork(network=network, net = ConvertToQuantNetwork(network=network,
quant_delay=quant_delay, quant_delay=quant_delay,

@ -160,7 +160,10 @@ def load_checkpoint(ckpt_file_name, net=None):
if not isinstance(ckpt_file_name, str): if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpt_file_name must be string.") raise ValueError("The ckpt_file_name must be string.")
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt": if not os.path.exists(ckpt_file_name):
raise ValueError("The checkpoint file is not exist.")
if ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.") raise ValueError("Please input the correct checkpoint file name.")
if os.path.getsize(ckpt_file_name) == 0: if os.path.getsize(ckpt_file_name) == 0:

@ -57,7 +57,7 @@ if __name__ == "__main__":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load check point into network # load check point into network
param_dict = load_checkpoint(args.ckpt_path, network.type) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
print("============== Starting Testing ==============") print("============== Starting Testing ==============")

@ -49,7 +49,7 @@ if __name__ == "__main__":
# define fusion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert fusion netwrok to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define loss # define loss

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
export quantization aware training network to infer `GEIR` backend.
"""
import argparse
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.quant import quant
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion
def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    recognises the usual textual spellings instead.

    Args:
        value (Union[str, bool]): Raw value supplied for the option.

    Returns:
        bool, the parsed flag.

    Raises:
        argparse.ArgumentTypeError: If `value` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, matching what ``bool("")`` returned before.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected, got {!r}.".format(value))


# Command-line interface. Only `device_target` and `ckpt_path` are read by the
# code below; `data_path` and `dataset_sink_mode` are accepted but not used in
# the visible part of this script (presumably kept for parity with the
# train/eval scripts -- confirm before removing).
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="",
                    help='if mode is test, must provide path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=_str2bool, default=True,
                    help='dataset_sink_mode is False or True')
args = parser.parse_args()

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
    # convert fusion network to quantization aware network
    # (done before loading so the parameters are loaded into the converted network)
    network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
    # load quantization aware network checkpoint
    param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)

    # export network
    # A single all-ones float32 tensor of shape [1, 1, image_height, image_width]
    # is passed as the example input for the export call.
    inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mindspore.float32)
    quant.export(network, inputs, file_name="lenet_quant", file_format='GEIR')

@ -22,7 +22,7 @@ import os
import argparse import argparse
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset from src.dataset import create_dataset
@ -54,7 +54,6 @@ if __name__ == "__main__":
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# call back and monitor # call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
@ -63,6 +62,6 @@ if __name__ == "__main__":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============") print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()], model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode) dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============") print("============== End Training ==============")

@ -23,7 +23,7 @@ import argparse
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.quant import quant from mindspore.train.quant import quant
@ -51,20 +51,19 @@ if __name__ == "__main__":
# define fusion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# load quantization aware network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, network.type) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define network loss # define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization # define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# call back and monitor # call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
@ -73,6 +72,6 @@ if __name__ == "__main__":
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============") print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()], model.train(cfg['epoch_size'], ds_train, callbacks=[ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode) dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============") print("============== End Training ==============")

Loading…
Cancel
Save