quant export fix up for atc tools

pull/4092/head
Wei Luning 5 years ago
parent a2b3235692
commit 776d094c5b

@ -270,6 +270,7 @@ class Parameter(MetaTensor):
"Update the parameter by a Tensor."
if isinstance(self, Tensor):
# for Tensor same shape:
self.init_flag = False
return self.assign_value(data)
# create a new tensor
return Parameter(data, self.name, self.requires_grad)

@ -29,6 +29,7 @@ from ...common import dtype as mstype
from ...common.api import _executor
from ...nn.layer import quant
from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations import _inner_ops as inner
from ...train import serialization
from . import quant_utils
@ -366,8 +367,6 @@ class ExportToQuantInferNetwork:
sqrt_mode = True
dequant_op = inner.Dequant(sqrt_mode)
# get op
op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
if isinstance(activation, _AddFakeQuantAfterSubCell):
activation = activation.subcell
elif hasattr(activation, "get_origin"):
@ -383,10 +382,17 @@ class ExportToQuantInferNetwork:
weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
# apply the quant
weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
weight = quant_utils.weight2int(weight, scale_w, zp_w)
if bias is not None:
bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
scale_deq = Tensor(scale_deq, mstype.float16)
# get op
if isinstance(cell_core, quant.DenseQuant):
op_core = P.MatMul()
weight = np.transpose(weight)
else:
op_core = cell_core.conv
weight = Tensor(weight, self.data_type)
block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
return block

@ -50,7 +50,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
power=optimizer_cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]

@ -52,7 +52,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
power=optimizer_cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]
optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)

@ -116,7 +116,7 @@ def run_pretrain():
power=cfg.Lamb.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.Lamb.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
{'params': other_params},
{'order_params': params}]
@ -132,7 +132,7 @@ def run_pretrain():
power=cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]

@ -52,7 +52,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
power=optimizer_cfg.AdamWeightDecay.power)
params = network.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0}]

@ -137,7 +137,7 @@ def run_pretrain():
power=cfg.Lamb.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.Lamb.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
{'params': other_params},
{'order_params': params}]
@ -153,7 +153,7 @@ def run_pretrain():
power=cfg.AdamWeightDecay.power)
params = net_with_loss.trainable_params()
decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]

@ -99,7 +99,7 @@ def run_general_distill():
power=common_cfg.AdamWeightDecay.power)
params = netwithloss.trainable_params()
decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]

@ -107,7 +107,7 @@ def run_predistill():
power=optimizer_cfg.AdamWeightDecay.power)
params = netwithloss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]
@ -165,7 +165,7 @@ def run_task_distill(ckpt_file):
power=optimizer_cfg.AdamWeightDecay.power)
params = netwithloss.trainable_params()
decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
other_params = list(filter(lambda x: x not in decay_params, params))
other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
{'params': other_params, 'weight_decay': 0.0},
{'order_params': params}]

Loading…
Cancel
Save