From 1503d1e23004a478064b03d969ddb30323fdebfc Mon Sep 17 00:00:00 2001 From: yuchaojie Date: Tue, 8 Dec 2020 09:33:59 +0800 Subject: [PATCH] use two_conv_fold for ascend quant net --- model_zoo/official/cv/mobilenetv2_quant/train.py | 3 ++- model_zoo/official/cv/resnet50_quant/train.py | 3 ++- model_zoo/official/cv/yolov3_darknet53_quant/train.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/model_zoo/official/cv/mobilenetv2_quant/train.py b/model_zoo/official/cv/mobilenetv2_quant/train.py index ede5f5d703..5fcaa9ba2f 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/train.py +++ b/model_zoo/official/cv/mobilenetv2_quant/train.py @@ -101,7 +101,8 @@ def train_on_ascend(): # convert fusion network to quantization aware network quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], - symmetric=[True, False]) + symmetric=[True, False], + one_conv_fold=False) network = quantizer.quantize(network) # get learning rate diff --git a/model_zoo/official/cv/resnet50_quant/train.py b/model_zoo/official/cv/resnet50_quant/train.py index ba7cfa6268..a5112066b9 100755 --- a/model_zoo/official/cv/resnet50_quant/train.py +++ b/model_zoo/official/cv/resnet50_quant/train.py @@ -115,7 +115,8 @@ if __name__ == '__main__': # convert fusion network to quantization aware network quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], - symmetric=[True, False]) + symmetric=[True, False], + one_conv_fold=False) net = quantizer.quantize(net) # get learning rate diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/train.py b/model_zoo/official/cv/yolov3_darknet53_quant/train.py index a4b9be26a0..1967aeac28 100644 --- a/model_zoo/official/cv/yolov3_darknet53_quant/train.py +++ b/model_zoo/official/cv/yolov3_darknet53_quant/train.py @@ -170,7 +170,8 @@ def train(): if config.quantization_aware: quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], - symmetric=[True, False]) + symmetric=[True, False], + one_conv_fold=False) network = quantizer.quantize(network) network = YoloWithLossCell(network)