Enable matmul and cleanup in QAT2 (#23657)

revert-23830-2.0-beta
Wojciech Uss 6 years ago committed by GitHub
parent 4d0efee4f4
commit 1753860dd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,10 +15,11 @@
import numpy as np
from .... import core
from ....framework import IrGraph
from ....framework import IrNode
__all__ = ['Qat2Int8MkldnnPass']
OpRole = core.op_proto_and_checker_maker.OpRole
class Qat2Int8MkldnnPass(object):
"""
@ -62,6 +63,7 @@ class Qat2Int8MkldnnPass(object):
self._pool_ops = ['pool2d']
self._mul_ops = ['mul']
self._fc_ops = ['fc']
self._matmul_ops = ['matmul']
self._weight_scales = {}
# Collect the Input and Output scales from Fake QAT models
self._var_quant_scales = {}
@ -79,9 +81,9 @@ class Qat2Int8MkldnnPass(object):
graph = self._compute_weight_scales(graph)
graph = self._update_relu_output_scales(graph)
graph = self._propagate_scales(graph)
graph = self._set_dummy_fc_out_scales(graph)
graph = self._set_dummy_out_scales(graph)
graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph)
graph = self._cleanup(graph)
return graph
def apply_fp32(self, graph):
@ -92,7 +94,7 @@ class Qat2Int8MkldnnPass(object):
graph = self._remove_fake_ops(graph)
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph)
graph = self._cleanup(graph)
return graph
def _convert_scale2tensor(self, scale):
@ -176,23 +178,29 @@ class Qat2Int8MkldnnPass(object):
return graph
def _set_dummy_fc_out_scales(self, graph):
def _set_dummy_out_scales(self, graph):
'''
For the output tensors of FC that do not have an assigned scale,
For the output tensors of fc, conv2d and matmul ops that do not have an assigned scale,
assign a dummy scale (same scale as input), so that the quantize pass
won't fail. In the end these scales aren't used, since FCs that
won't fail. In the end these scales aren't used, since the ops that
have an unassigned output scale will have a force_fp32_output attr
set to True.
'''
def _set_scale(op, op_types, input_names, output_name):
    """Give ``op``'s output a dummy scale copied from its first input.

    Nothing happens unless ``op`` is one of ``op_types``, its output
    variable has no entry in ``self._var_quant_scales`` yet, and every
    input listed in ``input_names`` already has an entry there.
    """
    known_scales = self._var_quant_scales
    if op.name() not in op_types:
        return
    if op.output(output_name)[0] in known_scales:
        return
    # Stop at the first input without a scale, so later inputs are not
    # looked up once one is found to be missing.
    for name in input_names:
        if op.input(name)[0] not in known_scales:
            return
    target_var = op.output(output_name)[0]
    source_var = op.input(input_names[0])[0]
    known_scales[target_var] = known_scales[source_var]
for op in graph.all_op_nodes():
if op.name() in self._fc_ops:
input_name = op.input("Input")[0]
output_name = op.output("Out")[0]
if input_name in self._var_quant_scales and \
output_name not in self._var_quant_scales:
# use input scale as a "dummy" scale
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
_set_scale(op, self._conv_ops, ["Input"], "Output")
_set_scale(op, self._fc_ops, ["Input"], "Out")
_set_scale(op, self._matmul_ops, ["X", "Y"], "Out")
return graph
@ -358,6 +366,15 @@ class Qat2Int8MkldnnPass(object):
self._remove_unused_var_nodes(graph)
return graph
def _cleanup(self, graph):
    """Run the final clean-up sequence on ``graph`` and return the result.

    The two IR passes listed below are applied in order through
    ``self._apply_pass``; the graph is then handed to
    ``self._remove_unused_var_nodes`` and ``self._set_op_role_forward``.
    """
    ir_pass_names = (
        'simplify_with_basic_ops_pass',  # remove dropout ops
        'mkldnn_inplace_pass',  # make some MKL-DNN ops work in place
    )
    for pass_name in ir_pass_names:
        graph = self._apply_pass(graph, pass_name)
    pruned_graph = self._remove_unused_var_nodes(graph)
    return self._set_op_role_forward(pruned_graph)
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_op_nodes()
@ -376,8 +393,14 @@ class Qat2Int8MkldnnPass(object):
graph.safe_remove_nodes(all_unused_vars)
return graph
def _set_op_role_forward(self, graph):
    """Set the "op_role" attribute of every op node to ``OpRole.Forward``.

    Returns the same ``graph`` object that was passed in.
    """
    for op_node in graph.all_op_nodes():
        op_node.set_attr("op_role", OpRole.Forward)
    return graph
def _compute_weight_scales(self, graph):
def _compute_var_scales(ops, out_name, w_name, axis):
def _compute_var_scales(ops, w_name, axis):
for op in graph.all_op_nodes():
if op.op().type() in ops:
weight_var_name = op.input(w_name)[0]
@ -394,8 +417,8 @@ class Qat2Int8MkldnnPass(object):
self._var_quant_scales[weight_var_name] = (use_unsigned_int,
lod_tensor)
_compute_var_scales(self._conv_ops, "Output", "Filter", axis=1)
_compute_var_scales(self._fc_ops, "Out", "W", axis=0)
_compute_var_scales(self._conv_ops, "Filter", axis=1)
_compute_var_scales(self._fc_ops, "W", axis=0)
return graph
def _find_avg_pooling_ids(self, graph):

@ -231,7 +231,7 @@ if(LINUX AND WITH_MKLDNN)
### QATv2 for NLP
set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2")
set(QAT2_NLP_QUANTIZED_OPS "fc,reshape2,transpose2,matmul")
set(NLP_DATA_ARCHIVE "Ernie_dataset.tar.gz")
set(NLP_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie_dataset")

Loading…
Cancel
Save