[dygraph qat] Refine saving output scale to infer program (#31784)

* Refine saving output scale to infer program
develop
cc 4 years ago committed by GitHub
parent 68497e7b39
commit 84a551380e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,7 +30,7 @@ op_real_in_out_name = {
"swish": [["X"], ["Out"]],
}
supported_quant_layers_map = {
quant_input_layers_map = {
'Conv2D': paddle.nn.Conv2D,
'Linear': paddle.nn.Linear,
'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
@ -58,8 +58,30 @@ fake_quantize_dequantize_types = [
"fake_quantize_dequantize_moving_average_abs_max"
]
out_scale_layers_list = (
paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.MaxPool2D,
paddle.nn.BatchNorm, paddle.nn.BatchNorm2D, paddle.nn.SyncBatchNorm,
paddle.nn.LeakyReLU, paddle.nn.PReLU, paddle.nn.ReLU, paddle.nn.ReLU6,
paddle.nn.Sigmoid, paddle.nn.Softmax, paddle.nn.Tanh, paddle.nn.Swish)
quant_output_layers_map = {
'Conv2D': paddle.nn.Conv2D,
'Conv2DTranspose': paddle.nn.Conv2DTranspose,
'Linear': paddle.nn.Linear,
'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
'AdaptiveMaxPool2D': paddle.nn.AdaptiveMaxPool2D,
'AvgPool2D': paddle.nn.AvgPool2D,
'MaxPool2D': paddle.nn.MaxPool2D,
'BatchNorm': paddle.nn.BatchNorm,
'BatchNorm2D': paddle.nn.BatchNorm2D,
'SyncBatchNorm': paddle.nn.SyncBatchNorm,
'ELU': paddle.nn.ELU,
'GELU': paddle.nn.GELU,
'LeakyReLU': paddle.nn.LeakyReLU,
'PReLU': paddle.nn.PReLU,
'ReLU': paddle.nn.ReLU,
'ReLU6': paddle.nn.ReLU6,
'Sigmoid': paddle.nn.Sigmoid,
'Softmax': paddle.nn.Softmax,
'Tanh': paddle.nn.Tanh,
'Swish': paddle.nn.Swish,
}
weight_op_types = [
"conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose",
"depthwise_conv2d_transpose"
]

@ -33,7 +33,6 @@ from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.layer import ReLU, LeakyReLU, Sigmoid, Softmax, PReLU
from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D
from paddle.fluid.dygraph.nn import Pool2D
from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph import nn
@ -131,8 +130,8 @@ class ImperativeLenet(fluid.dygraph.Layer):
bias_attr=False),
BatchNorm2D(6),
ReLU(),
Pool2D(
pool_size=2, pool_type='max', pool_stride=2),
MaxPool2D(
kernel_size=2, stride=2),
Conv2D(
in_channels=6,
out_channels=16,
@ -357,7 +356,6 @@ class TestImperativeOutSclae(unittest.TestCase):
"diff({}) at {}, dynamic loss = {}, static loss = {}".
format(diff, i, loss_d, loss_s))
break
self.assertTrue(
np.allclose(
np.array(dynamic_loss_rec),
@ -398,10 +396,15 @@ class TestImperativeOutSclae(unittest.TestCase):
if dynamic_ops[i].has_attr("out_threshold"):
op_count += 1
self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
if dynamic_ops[i].attr("out_threshold") != static_ops[i].attr(
"out_threshold"):
_logger.info(dynamic_ops[i].attr("out_threshold"))
_logger.info(static_ops[i].attr("out_threshold"))
self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
static_ops[i].attr("out_threshold"))
self.assertTrue(op_count == 13)
_logger.info("op_cout: {}".format(op_count))
self.assertTrue(op_count == 14)
class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
@ -470,7 +473,9 @@ class TestSaveQuanztizedModelFromCheckPoint(unittest.TestCase):
self.assertTrue(dynamic_ops[i].type == static_ops[i].type)
self.assertTrue(dynamic_ops[i].attr("out_threshold") ==
static_ops[i].attr("out_threshold"))
self.assertTrue(op_count == 13)
_logger.info("op_cout: {}".format(op_count))
self.assertTrue(op_count == 14)
class TestSaveQuantizedModel_Warning(unittest.TestCase):
@ -490,8 +495,10 @@ class TestSaveQuantizedModel_Warning(unittest.TestCase):
shape=[None, 1, 28, 28], dtype='float32')
])
warning_message = "Warning: No Layer of the model while to be saved contains the out_threshold attribute, " \
"so the generated inference model would not contain the out_threshold."
warning_message = "Warning: No Layer of the model while to be " \
"saved contains the out_threshold attribute, so the " \
"generated inference model would not contain the " \
"out_threshold."
num = get_vaild_warning_num(warning_message, w)
assert num == 1

Loading…
Cancel
Save