From 776d094c5be6eba1a9c67d3bb611b9b19d7a6236 Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Fri, 7 Aug 2020 11:10:41 +0800
Subject: [PATCH] quant export fix up for atc tools

---
 mindspore/common/parameter.py                        |  1 +
 mindspore/train/quant/quant.py                       | 12 +++++++++---
 model_zoo/official/nlp/bert/run_classifier.py        |  2 +-
 model_zoo/official/nlp/bert/run_ner.py               |  2 +-
 model_zoo/official/nlp/bert/run_pretrain.py          |  4 ++--
 model_zoo/official/nlp/bert/run_squad.py             |  2 +-
 model_zoo/official/nlp/bert_thor/run_pretrain.py     |  4 ++--
 .../official/nlp/tinybert/run_general_distill.py     |  2 +-
 model_zoo/official/nlp/tinybert/run_task_distill.py  |  4 ++--
 9 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index f29f3a0a5c..df7ebb2d9d 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -270,6 +270,7 @@ class Parameter(MetaTensor):
         "Update the parameter by a Tensor."
         if isinstance(self, Tensor):
             # for Tensor same shape:
+            self.init_flag = False
             return self.assign_value(data)
         # create a new tensor
         return Parameter(data, self.name, self.requires_grad)
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index aae2fd74b2..7a8b8c9bf5 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -29,6 +29,7 @@ from ...common import dtype as mstype
 from ...common.api import _executor
 from ...nn.layer import quant
 from ...ops import functional as F
+from ...ops import operations as P
 from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils
@@ -366,8 +367,6 @@ class ExportToQuantInferNetwork:
             sqrt_mode = True
         dequant_op = inner.Dequant(sqrt_mode)
 
-        # get op
-        op_core = cell_core.matmul if isinstance(cell_core, quant.DenseQuant) else cell_core.conv
         if isinstance(activation, _AddFakeQuantAfterSubCell):
             activation = activation.subcell
         elif hasattr(activation, "get_origin"):
@@ -383,10 +382,17 @@ class ExportToQuantInferNetwork:
             weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
 
         # apply the quant
-        weight = Tensor(quant_utils.weight2int(weight, scale_w, zp_w), self.data_type)
+        weight = quant_utils.weight2int(weight, scale_w, zp_w)
         if bias is not None:
             bias = Tensor(scale_a_in * scale_w * bias, mstype.int32)
         scale_deq = Tensor(scale_deq, mstype.float16)
+        # get op
+        if isinstance(cell_core, quant.DenseQuant):
+            op_core = P.MatMul()
+            weight = np.transpose(weight)
+        else:
+            op_core = cell_core.conv
+        weight = Tensor(weight, self.data_type)
         block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
         return block
 
diff --git a/model_zoo/official/nlp/bert/run_classifier.py b/model_zoo/official/nlp/bert/run_classifier.py
index c3663a5727..d2278bbc3c 100644
--- a/model_zoo/official/nlp/bert/run_classifier.py
+++ b/model_zoo/official/nlp/bert/run_classifier.py
@@ -50,7 +50,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
                                        power=optimizer_cfg.AdamWeightDecay.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: x not in decay_params, params))
+        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                         {'params': other_params, 'weight_decay': 0.0}]
         optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
diff --git a/model_zoo/official/nlp/bert/run_ner.py b/model_zoo/official/nlp/bert/run_ner.py
index 1ea6893945..b311950315 100644
--- a/model_zoo/official/nlp/bert/run_ner.py
+++ b/model_zoo/official/nlp/bert/run_ner.py
@@ -52,7 +52,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
                                        power=optimizer_cfg.AdamWeightDecay.power)
         params = network.trainable_params()
         decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: x not in decay_params, params))
+        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                         {'params': other_params, 'weight_decay': 0.0}]
         optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py
index 1f31ff4015..6b4cb1548a 100644
--- a/model_zoo/official/nlp/bert/run_pretrain.py
+++ b/model_zoo/official/nlp/bert/run_pretrain.py
@@ -116,7 +116,7 @@ def run_pretrain():
                                        power=cfg.Lamb.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
-        other_params = list(filter(lambda x: x not in decay_params, params))
+        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                         {'params': other_params},
                         {'order_params': params}]
@@ -132,7 +132,7 @@ def run_pretrain():
                                        power=cfg.AdamWeightDecay.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: x not in decay_params, params))
+        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                         {'params': other_params, 'weight_decay': 0.0},
                         {'order_params': params}]
diff --git a/model_zoo/official/nlp/bert/run_squad.py b/model_zoo/official/nlp/bert/run_squad.py
index 972f9dcdfc..a026408e7c 100644
--- a/model_zoo/official/nlp/bert/run_squad.py
+++ b/model_zoo/official/nlp/bert/run_squad.py
@@ -52,7 +52,7 @@ def do_train(dataset=None, network=None, load_checkpoint_path="", save_checkpoin
                                        power=optimizer_cfg.AdamWeightDecay.power)
         params = network.trainable_params()
         decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: x not in decay_params, params))
+        other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                         {'params': other_params, 'weight_decay': 0.0}]
         optimizer = AdamWeightDecay(group_params, lr_schedule, eps=optimizer_cfg.AdamWeightDecay.eps)
diff --git a/model_zoo/official/nlp/bert_thor/run_pretrain.py b/model_zoo/official/nlp/bert_thor/run_pretrain.py
index 08161c7a13..0ec84545db 100644
--- a/model_zoo/official/nlp/bert_thor/run_pretrain.py
+++ b/model_zoo/official/nlp/bert_thor/run_pretrain.py
@@ -137,7 +137,7 @@ def run_pretrain():
                                        power=cfg.Lamb.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.Lamb.decay_filter, params))
-        other_params = list(filter(lambda x: x not in decay_params, params))
+        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                         {'params': other_params},
                         {'order_params': params}]
@@ -153,7 +153,7 @@ def run_pretrain():
                                        power=cfg.AdamWeightDecay.power)
         params = net_with_loss.trainable_params()
         decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
-        other_params = list(filter(lambda x: x not in decay_params, params))
+        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
         group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                         {'params': other_params, 'weight_decay': 0.0},
                         {'order_params': params}]
diff --git a/model_zoo/official/nlp/tinybert/run_general_distill.py b/model_zoo/official/nlp/tinybert/run_general_distill.py
index c0e1044773..50e586f0af 100644
--- a/model_zoo/official/nlp/tinybert/run_general_distill.py
+++ b/model_zoo/official/nlp/tinybert/run_general_distill.py
@@ -99,7 +99,7 @@ def run_general_distill():
                                    power=common_cfg.AdamWeightDecay.power)
     params = netwithloss.trainable_params()
     decay_params = list(filter(common_cfg.AdamWeightDecay.decay_filter, params))
-    other_params = list(filter(lambda x: x not in decay_params, params))
+    other_params = list(filter(lambda x: not common_cfg.AdamWeightDecay.decay_filter(x), params))
     group_params = [{'params': decay_params, 'weight_decay': common_cfg.AdamWeightDecay.weight_decay},
                     {'params': other_params, 'weight_decay': 0.0},
                     {'order_params': params}]
diff --git a/model_zoo/official/nlp/tinybert/run_task_distill.py b/model_zoo/official/nlp/tinybert/run_task_distill.py
index 12a3acda48..9469c475d2 100644
--- a/model_zoo/official/nlp/tinybert/run_task_distill.py
+++ b/model_zoo/official/nlp/tinybert/run_task_distill.py
@@ -107,7 +107,7 @@ def run_predistill():
                                    power=optimizer_cfg.AdamWeightDecay.power)
     params = netwithloss.trainable_params()
     decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-    other_params = list(filter(lambda x: x not in decay_params, params))
+    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
     group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                     {'params': other_params, 'weight_decay': 0.0},
                     {'order_params': params}]
@@ -165,7 +165,7 @@ def run_task_distill(ckpt_file):
                                    power=optimizer_cfg.AdamWeightDecay.power)
     params = netwithloss.trainable_params()
     decay_params = list(filter(optimizer_cfg.AdamWeightDecay.decay_filter, params))
-    other_params = list(filter(lambda x: x not in decay_params, params))
+    other_params = list(filter(lambda x: not optimizer_cfg.AdamWeightDecay.decay_filter(x), params))
     group_params = [{'params': decay_params, 'weight_decay': optimizer_cfg.AdamWeightDecay.weight_decay},
                     {'params': other_params, 'weight_decay': 0.0},
                     {'order_params': params}]
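
Note on the quant.py hunk above: DenseQuant keeps its weight in the (out_channels, in_channels) layout of nn.Dense and multiplies activations by the transposed weight, whereas a plain P.MatMul() with default flags computes x @ W. Exporting the cell through P.MatMul(), presumably so the graph handed to the ATC tool carries an ordinary MatMul, therefore needs the already-quantized weight transposed once at export time, which is what the np.transpose call does; the convolution path keeps cell_core.conv and its weight layout unchanged. A minimal NumPy sketch of the idea follows; weight2int_sketch is only a stand-in for quant_utils.weight2int (assumed here to be a simple affine rounding), and the shapes and scale are illustrative.

    import numpy as np

    def weight2int_sketch(weight, scale, zero_point):
        # stand-in for quant_utils.weight2int: affine rounding into the int8 range
        return np.clip(np.round(weight / scale + zero_point), -128, 127).astype(np.int8)

    w_float = np.random.randn(3, 4).astype(np.float32)    # Dense weight: (out_channels, in_channels)
    w_int8 = weight2int_sketch(w_float, scale=0.02, zero_point=0)

    # DenseQuant computes x @ W.T; P.MatMul() with default flags computes x @ W,
    # so the weight is transposed once at export time instead of at inference time.
    w_for_matmul = np.transpose(w_int8)                    # (in_channels, out_channels)

    x = np.random.randn(2, 4).astype(np.float32)           # activations: (batch, in_channels)
    assert np.allclose(x @ w_for_matmul, x @ w_int8.T)     # same result either way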
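
Note on the model_zoo hunks above: both the old and the new lambda split the trainable parameters into a weight-decay group and a no-decay group, but the old form, lambda x: x not in decay_params, does a linear membership scan per parameter and depends on how Parameter objects compare for equality, while re-applying the decay_filter predicate makes the two groups exact complements by construction. A small self-contained sketch of the grouping pattern, using plain name strings and an illustrative filter rather than the real optimizer_cfg / cfg objects:

    # illustrative stand-ins: the real scripts filter Parameter objects with
    # optimizer_cfg.AdamWeightDecay.decay_filter / cfg.Lamb.decay_filter
    def decay_filter(name):
        # apply weight decay to everything except LayerNorm and bias parameters
        return not any(key in name.lower() for key in ('layernorm', '.bias'))

    params = ['dense.weight', 'dense.bias', 'encoder.layernorm.gamma', 'embedding.weight']

    decay_params = list(filter(decay_filter, params))
    # new form: re-apply the predicate rather than testing membership in decay_params
    other_params = list(filter(lambda x: not decay_filter(x), params))

    group_params = [{'params': decay_params, 'weight_decay': 0.01},
                    {'params': other_params, 'weight_decay': 0.0}]
    assert set(decay_params) | set(other_params) == set(params)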