aware quantization training auto create graph

pull/2337/head
chenzomi 5 years ago
parent 182215e060
commit c268c88220

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@ -279,7 +279,7 @@ class FakeQuantWithMinMax(Cell):
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
out_channels (int): Declare the min and max channel size. Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
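For orientation, a minimal sketch of constructing this cell with per-channel statistics for a 16-channel weight tensor; the import path and the channel count are assumptions, not part of this diff:

from mindspore.nn.layer import quant

# Hypothetical per-channel fake-quant node: one (min, max) pair per output channel.
weight_fq = quant.FakeQuantWithMinMax(min_init=-6,
                                       max_init=6,
                                       num_bits=8,
                                       ema=True,
                                       ema_decay=0.999,
                                       per_channel=True,
                                       channel_axis=0,
                                       out_channels=16)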
@ -407,7 +407,7 @@ class Conv2dBatchNormQuant(Cell):
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -584,7 +584,7 @@ class Conv2dQuant(Cell):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -694,7 +694,7 @@ class DenseQuant(Cell):
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -797,6 +797,7 @@ class ReLUQuant(_QuantActivation):
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -816,6 +817,7 @@ class ReLUQuant(_QuantActivation):
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
symmetric=False,
narrow_range=False):
super(ReLUQuant, self).__init__()
@ -824,6 +826,7 @@ class ReLUQuant(_QuantActivation):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
@ -850,6 +853,7 @@ class ReLU6Quant(_QuantActivation):
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -869,6 +873,7 @@ class ReLU6Quant(_QuantActivation):
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
symmetric=False,
narrow_range=False):
super(ReLU6Quant, self).__init__()
@ -877,6 +882,7 @@ class ReLU6Quant(_QuantActivation):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
@ -900,6 +906,7 @@ class HSwishQuant(_QuantActivation):
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -919,6 +926,7 @@ class HSwishQuant(_QuantActivation):
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
symmetric=False,
narrow_range=False):
super(HSwishQuant, self).__init__()
@ -927,6 +935,7 @@ class HSwishQuant(_QuantActivation):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
@ -935,6 +944,7 @@ class HSwishQuant(_QuantActivation):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
@ -959,6 +969,7 @@ class HSigmoidQuant(_QuantActivation):
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -978,6 +989,7 @@ class HSigmoidQuant(_QuantActivation):
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
symmetric=False,
narrow_range=False):
super(HSigmoidQuant, self).__init__()
@ -986,6 +998,7 @@ class HSigmoidQuant(_QuantActivation):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
symmetric=symmetric,
narrow_range=narrow_range)
self.fake_quant_act_after = FakeQuantWithMinMax(min_init=-6,
@ -993,6 +1006,7 @@ class HSigmoidQuant(_QuantActivation):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
@ -1017,6 +1031,7 @@ class TensorAddQuant(Cell):
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -1037,6 +1052,7 @@ class TensorAddQuant(Cell):
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
symmetric=False,
narrow_range=False):
super(TensorAddQuant, self).__init__()
@ -1045,6 +1061,7 @@ class TensorAddQuant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
@ -1066,6 +1083,7 @@ class MulQuant(Cell):
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
@ -1081,6 +1099,7 @@ class MulQuant(Cell):
num_bits=8,
quant_delay=0,
ema_decay=0.999,
per_channel=False,
symmetric=False,
narrow_range=False):
super(MulQuant, self).__init__()
@ -1089,6 +1108,7 @@ class MulQuant(Cell):
num_bits=num_bits,
quant_delay=quant_delay,
ema=True,
per_channel=per_channel,
ema_decay=ema_decay,
symmetric=symmetric,
narrow_range=narrow_range)
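Each of the activation and arithmetic quantization cells above (ReLUQuant, ReLU6Quant, HSwishQuant, HSigmoidQuant, TensorAddQuant, MulQuant) gains the same per_channel argument and forwards it to its internal FakeQuantWithMinMax. A minimal sketch of constructing one of them with the new argument (values are illustrative, not from this diff):

# Hypothetical construction of the updated ReLUQuant cell; activations are
# normally quantized per layer, so per_channel stays False here.
relu_q = quant.ReLUQuant(num_bits=8,
                         quant_delay=0,
                         ema_decay=0.999,
                         per_channel=False,
                         symmetric=False,
                         narrow_range=False)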

@ -14,6 +14,7 @@
# ============================================================================
"""LossMonitor Callback class."""
import time
import numpy as np
from mindspore.common.tensor import Tensor
@ -31,32 +32,62 @@ class LossMonitor(Callback):
Args:
per_print_times (int): Print the loss every given number of steps. Default: 1.
lr_init (numpy.ndarray): Training learning rate. Default: None.
Raises:
ValueError: If print_step is not int or less than zero.
Examples:
>>> LossMonitor(100, lr_init=Tensor([0.05]*100).asnumpy())
"""
def __init__(self, per_print_times=1):
def __init__(self, per_print_times=1, lr_init=None):
super(LossMonitor, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.lr_init = lr_init
def epoch_begin(self, run_context):
self.losses = []
self.epoch_time = time.time()
def epoch_end(self, run_context):
cb_params = run_context.original_args()
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / cb_params.batch_num
print("Epoch time: {:5.3f}, per step time: {:5.3f}, "
"avg loss: {:5.3f}".format(epoch_mseconds,
per_step_mseconds,
np.mean(self.losses)))
print("*" * 60)
def step_begin(self, run_context):
self.step_time = time.time()
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs
step_mseconds = (time.time() - self.step_time) * 1000
step_loss = cb_params.net_outputs
if isinstance(loss, (tuple, list)):
if isinstance(loss[0], Tensor) and isinstance(loss[0].asnumpy(), np.ndarray):
loss = loss[0]
if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor):
step_loss = step_loss[0]
if isinstance(step_loss, Tensor):
step_loss = np.mean(step_loss.asnumpy())
if isinstance(loss, Tensor) and isinstance(loss.asnumpy(), np.ndarray):
loss = np.mean(loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
"Invalid loss, terminating training.".format(
cb_params.cur_epoch_num - 1, cb_params.epoch_num,
cur_step_in_epoch, cb_params.batch_num))
if isinstance(loss, float) and (np.isnan(loss) or np.isinf(loss)):
raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format(
cb_params.cur_epoch_num, cur_step_in_epoch))
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("epoch: %s step: %s, loss is %s" % (cb_params.cur_epoch_num, cur_step_in_epoch, loss), flush=True)
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
cb_params.cur_epoch_num - 1, cb_params.epoch_num,
cur_step_in_epoch, cb_params.batch_num,
step_loss, np.mean(self.losses),
step_mseconds), flush=True)
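The reworked LossMonitor now keeps a per-epoch loss list, times each step, and prints the step loss alongside the running average. A sketch of driving it through one step by hand, mirroring the updated tests at the bottom of this diff (import paths are assumptions):

from mindspore import Tensor
from mindspore.train.callback import LossMonitor, RunContext, _InternalCallbackParam

# Fill in the fields the callback reads, as the updated tests do.
cb_params = _InternalCallbackParam()
cb_params.cur_epoch_num = 1
cb_params.epoch_num = 4
cb_params.cur_step_num = 1
cb_params.batch_num = 1
cb_params.net_outputs = Tensor(2.0)
run_context = RunContext(cb_params)

loss_cb = LossMonitor(per_print_times=1)
loss_cb.epoch_begin(run_context)   # resets self.losses and the epoch timer
loss_cb.step_begin(run_context)    # starts the per-step timer
loss_cb.step_end(run_context)      # prints "Epoch: [...], step: [...], loss: [...], time: [...]"
loss_cb.epoch_end(run_context)     # prints epoch time, per-step time and average loss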

@ -32,4 +32,4 @@ class TimeMonitor(Callback):
def epoch_end(self, run_context):
epoch_mseconds = (time.time() - self.epoch_time) * 1000
per_step_mseconds = epoch_mseconds / self.data_size
print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True)
print("Epoch time: {:5.3f}, per step time: {:5.3f}".format(epoch_mseconds, per_step_mseconds), flush=True)

@ -32,6 +32,7 @@ from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils
_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
nn.ReLU6: quant.ReLU6Quant,
nn.HSigmoid: quant.HSigmoidQuant,
@ -61,14 +62,17 @@ class _AddFakeQuantAfterSubCell(nn.Cell):
Add FakeQuant after of the sub Cell.
"""
def __init__(self, subcell, quant_delay=0, num_bits=8):
def __init__(self, subcell, **kwargs):
super(_AddFakeQuantAfterSubCell, self).__init__(auto_prefix=False)
self.subcell = subcell
self.fake_quant_act = quant.FakeQuantWithMinMax(min_init=-6,
max_init=6,
num_bits=num_bits,
quant_delay=quant_delay,
ema=True)
ema=True,
num_bits=kwargs["num_bits"],
quant_delay=kwargs["quant_delay"],
per_channel=kwargs["per_channel"],
symmetric=kwargs["symmetric"],
narrow_range=kwargs["narrow_range"])
def construct(self, *data):
output = self.subcell(*data)
@ -82,30 +86,20 @@ class ConvertToQuantNetwork:
"""
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
def __init__(self,
network,
quant_delay=0,
bn_fold=False,
freeze_bn=0,
weight_bits=8,
act_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False):
self.network = validator.check_isinstance(
'network', network, (nn.Cell,))
self.quant_delay = validator.check_integer(
"quant delay", quant_delay, 0, Rel.GE)
self.freeze_bn = validator.check_integer(
"freeze bn", freeze_bn, 0, Rel.GE)
self.weight_bits = validator.check_integer(
"weights bit", weight_bits, 0, Rel.GE)
self.act_bits = validator.check_integer(
"activations bit", act_bits, 0, Rel.GE)
self.bn_fold = validator.check_bool("bn fold", bn_fold)
self.per_channel = validator.check_bool("per channel", per_channel)
self.symmetric = validator.check_bool("symmetric", symmetric)
self.narrow_range = validator.check_bool("narrow range", narrow_range)
def __init__(self, **kwargs):
self.network = validator.check_isinstance('network', kwargs["network"], (nn.Cell,))
self.weight_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][0], 0, Rel.GE)
self.act_qdelay = validator.check_integer("quant delay", kwargs["quant_delay"][-1], 0, Rel.GE)
self.bn_fold = validator.check_bool("bn fold", kwargs["bn_fold"])
self.freeze_bn = validator.check_integer("freeze bn", kwargs["freeze_bn"], 0, Rel.GE)
self.weight_bits = validator.check_integer("weights bit", kwargs["num_bits"][0], 0, Rel.GE)
self.act_bits = validator.check_integer("activations bit", kwargs["num_bits"][-1], 0, Rel.GE)
self.weight_channel = validator.check_bool("per channel", kwargs["per_channel"][0])
self.act_channel = validator.check_bool("per channel", kwargs["per_channel"][-1])
self.weight_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][0])
self.act_symmetric = validator.check_bool("symmetric", kwargs["symmetric"][-1])
self.weight_range = validator.check_bool("narrow range", kwargs["narrow_range"][0])
self.act_range = validator.check_bool("narrow range", kwargs["narrow_range"][-1])
self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
quant.DenseBnAct: self._convert_dense}
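A note on the `[0]`/`[-1]` indexing above: `convert2list` (introduced later in this diff) turns a scalar argument into a single-element list, so index 0 and index -1 read the same value, while an explicit pair splits into weight and activation settings. Illustrative only:

single = [8]        # scalar num_bits after convert2list
pair = [8, 4]       # explicit (weight, activation) pair
assert single[0] == single[-1] == 8      # one value shared by weights and activations
assert (pair[0], pair[-1]) == (8, 4)     # weights use 8 bits, activations use 4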
@ -153,7 +147,12 @@ class ConvertToQuantNetwork:
add_list.append((name, attr))
for name, prim_op in add_list:
prefix = name
add_quant = _AddFakeQuantAfterSubCell(prim_op) # quant.TensorAddQuant()
add_quant = _AddFakeQuantAfterSubCell(prim_op,
num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
prefix = '.'.join([network.param_prefix, self._convert_op_name(prim_op.name)])
add_quant.update_parameters_name(prefix + '.')
del network.__dict__[name]
@ -177,13 +176,13 @@ class ConvertToQuantNetwork:
group=conv_inner.group,
eps=bn_inner.eps,
momentum=bn_inner.momentum,
quant_delay=self.quant_delay,
quant_delay=self.weight_qdelay,
freeze_bn=self.freeze_bn,
per_channel=self.per_channel,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
fake=True,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
del subcell.batchnorm
subcell.batchnorm = None
subcell.has_bn = False
@ -197,18 +196,22 @@ class ConvertToQuantNetwork:
dilation=conv_inner.dilation,
group=conv_inner.group,
has_bias=conv_inner.has_bias,
quant_delay=self.quant_delay,
per_channel=self.per_channel,
quant_delay=self.weight_qdelay,
per_channel=self.weight_channel,
num_bits=self.weight_bits,
symmetric=self.symmetric,
narrow_range=self.narrow_range)
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
subcell.conv = conv_inner
if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
else:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
quant_delay=self.quant_delay)
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
return subcell
def _convert_dense(self, subcell):
@ -219,16 +222,22 @@ class ConvertToQuantNetwork:
dense_inner = quant.DenseQuant(dense_inner.in_channels,
dense_inner.out_channels,
has_bias=dense_inner.has_bias,
quant_delay=self.quant_delay,
per_channel=self.per_channel,
num_bits=self.weight_bits)
num_bits=self.weight_bits,
quant_delay=self.weight_qdelay,
per_channel=self.weight_channel,
symmetric=self.weight_symmetric,
narrow_range=self.weight_range)
subcell.dense = dense_inner
if subcell.has_act and subcell.activation is not None:
subcell.activation = self._convert_activation(subcell.activation)
else:
subcell.has_act = True
subcell.activation = _AddFakeQuantAfterSubCell(F.identity, num_bits=self.act_bits,
quant_delay=self.quant_delay)
subcell.activation = _AddFakeQuantAfterSubCell(F.identity,
num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
return subcell
def _convert_activation(self, activation):
@ -236,7 +245,11 @@ class ConvertToQuantNetwork:
if act_class not in _ACTIVATION_MAP:
raise ValueError(
"Unsupported activation in auto Quant: ", act_class)
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits, quant_delay=self.quant_delay)
return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
quant_delay=self.act_qdelay,
per_channel=self.act_channel,
symmetric=self.act_symmetric,
narrow_range=self.act_range)
class ExportQuantNetworkDeploy:
@ -381,32 +394,57 @@ def export_geir(network, *inputs, file_name):
def convert_quant_network(network,
quant_delay=0,
bn_fold=False,
freeze_bn=0,
weight_bits=8,
act_bits=8,
per_channel=False,
symmetric=False,
narrow_range=False
quant_delay=(0, 0),
num_bits=(8, 8),
per_channel=(False, False),
symmetric=(False, False),
narrow_range=(False, False)
):
r"""
Create an aware quantization training network.
Args:
network (Cell): Network to be converted to an aware quantization training network.
quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
quant_delay (list of int): Number of steps after which weights and activations are quantized during
eval. The first element represents weights and the second element represents data flow. Default: [0, 0]
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0.
weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
act_bits (int): Number of bits to use for quantizing activations. Default: 8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
freeze_bn (int): Number of steps after which the BatchNorm OP parameters use the total mean and variance. Default: 0.
num_bits (list of int): Number of bits to use for quantizing weights and activations. The first
element represents weights and the second element represents data flow. Default: [8, 8]
per_channel (list of bool): Quantization granularity based on layer or on channel. If `True`,
quantization is applied per channel; otherwise it is applied per layer. The first element
represents weights and the second element represents data flow. Default: [False, False]
symmetric (list of bool): Whether the quantization algorithm is symmetric or not. If `True`, the
algorithm is symmetric; otherwise it is asymmetric. The first element represents weights and
the second element represents data flow. Default: [False, False]
narrow_range (list of bool): Whether the quantization algorithm uses a narrow range or not. If `True`,
the narrow range is used; otherwise the full range is used. The first element represents weights
and the second element represents data flow. Default: [False, False]
Returns:
Cell, Network which has change to aware quantization training network.
Cell, network converted into an aware quantization training network.
"""
net = ConvertToQuantNetwork(
network, quant_delay, bn_fold, freeze_bn, weight_bits, act_bits, per_channel, symmetric, narrow_range)
def convert2list(name, value):
if not isinstance(value, list) and not isinstance(value, tuple):
value = [value]
elif len(value) > 2:
raise ValueError("input `{}` len should less then 2".format(name))
return value
quant_delay = convert2list("quant delay", quant_delay)
num_bits = convert2list("num bits", num_bits)
per_channel = convert2list("per channel", per_channel)
symmetric = convert2list("symmetric", symmetric)
narrow_range = convert2list("narrow range", narrow_range)
net = ConvertToQuantNetwork(network=network,
quant_delay=quant_delay,
bn_fold=bn_fold,
freeze_bn=freeze_bn,
num_bits=num_bits,
per_channel=per_channel,
symmetric=symmetric,
narrow_range=narrow_range)
return net.run()
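A sketch of calling the reworked API; the module path and the LeNet5 network are taken from the updated tests further down, and the exact values are illustrative. Scalars are still accepted and broadcast to a (weights, activations) pair by convert2list:

from mindspore.train.quant import quant as qat   # import path assumed from the tests below

net = LeNet5()                                   # any nn.Cell works here
net = qat.convert_quant_network(net,
                                bn_fold=True,
                                freeze_bn=10000,
                                quant_delay=(0, 0),          # (weights, activations)
                                num_bits=8,                  # scalar, broadcast to (8, 8)
                                per_channel=(True, False),   # per-channel weights only
                                symmetric=(False, False),
                                narrow_range=(False, False))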

@ -54,4 +54,4 @@ if __name__ == "__main__":
cfg.batch_size,
status="test")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))
print("============== {} ==============".format(acc))

@ -61,4 +61,4 @@ if __name__ == "__main__":
cfg.batch_size,
1)
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))
print("============== {} ==============".format(acc))

@ -78,4 +78,4 @@ if __name__ == '__main__':
acc = model.eval(ds_eval, dataset_sink_mode=False)
else:
acc = model.eval(ds_eval)
print("============== Accuracy:{} ==============".format(acc))
print("============== {} ==============".format(acc))

@ -203,4 +203,4 @@ def test_train_and_eval_lenet():
print("============== Starting Testing ==============")
ds_eval = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "test"), 32, 1)
acc = model.eval(ds_eval, dataset_sink_mode=True)
print("============== Accuracy:{} ==============".format(acc))
print("============== {} ==============".format(acc))

@ -67,7 +67,7 @@ def test_qat_lenet():
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
net = LeNet5()
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
net, quant_delay=0, bn_fold=False, freeze_bn=10000, num_bits=8)
# should load the checkpoint. mock here
for param in net.get_parameters():
param.init_data()
@ -79,7 +79,7 @@ def test_qat_mobile():
net = MobileNetV2()
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=True, freeze_bn=10000, weight_bits=8, act_bits=8)
net, quant_delay=0, bn_fold=True, freeze_bn=10000, num_bits=8)
# should load the checkpoint. mock here
for param in net.get_parameters():
param.init_data()

@ -117,6 +117,7 @@ def test_loss_monitor_sink_mode():
"""Test loss monitor sink mode."""
cb_params = _InternalCallbackParam()
cb_params.cur_epoch_num = 4
cb_params.epoch_num = 4
cb_params.cur_step_num = 2
cb_params.batch_num = 2
cb_params.net_outputs = Tensor(2.0)
@ -138,6 +139,7 @@ def test_loss_monitor_normal_mode():
run_context = RunContext(cb_params)
loss_cb = LossMonitor(1)
cb_params.cur_epoch_num = 4
cb_params.epoch_num = 4
cb_params.cur_step_num = 1
cb_params.batch_num = 1
cb_params.net_outputs = Tensor(2.0)
