@ -375,11 +375,15 @@ class FakeQAT2MkldnnINT8PerfPass(object):
if op . name ( ) in self . _fake_quantize_types :
op_out = graph . _find_node_by_name ( op . outputs ,
op . output ( " Out " ) [ 0 ] )
self . _remove_fake_quantize ( graph , op )
next_op = op_out . outputs [ 0 ]
if next_op . name ( ) not in self . _mul_ops :
self . _remove_fake_quantize ( graph , op )
else :
quant_op = self . _transform_to_quantize_mkldnn ( graph , op )
self . _transform_to_mul_mkldnn ( graph , next_op , quant_op )
for op in graph . all_op_nodes ( ) :
if op . name ( ) in self . _fake_dequantize_types :
op_in = graph . _find_node_by_name ( op . inputs , op . input ( " X " ) [ 0 ] )
self . _remove_fake_dequantize ( graph , op )
return graph
@ -426,8 +430,6 @@ class FakeQAT2MkldnnINT8PerfPass(object):
for op in graph . all_op_nodes ( ) :
if op . name ( ) in self . _conv_ops :
self . _dequantize_conv_weights ( graph , op )
elif op . name ( ) in self . _mul_ops :
self . _dequantize_mul_weights ( graph , op )
return graph
def _dequantize_conv_weights ( self , graph , op_node ) :
@ -463,22 +465,20 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph = self . _apply_pass ( graph , ' conv_elementwise_add_mkldnn_fuse_pass ' )
graph = self . _apply_pass ( graph , ' conv_relu_mkldnn_fuse_pass ' )
graph = self . _apply_pass ( graph , ' conv_relu6_mkldnn_fuse_pass ' )
graph = self . _apply_pass ( graph , ' fc_fuse_pass ' )
return graph
def _apply_pass ( self , graph , pass_name , attrs = None , attr_values = None ) :
ir_pass = core . get_pass ( pass_name )
inference_program = graph . to_program ( )
ir_graph = core . Graph ( inference_program . desc )
ir _graph. set_not_owned ( ' __param_scope__ ' , self . _scope )
cpp_graph = graph . graph
if not cpp_graph . has ( ' __param_scope__ ' ) :
cpp _graph. set_not_owned ( ' __param_scope__ ' , self . _scope )
if attrs :
assert attr_values and len ( attrs ) == len (
attr_values
) , " Different number of pass attributes and their values. "
for attr , value in zip ( attrs , attr_values ) :
ir_pass . set ( attr , value )
ir_pass . apply ( ir_graph )
graph = IrGraph ( ir_graph , for_test = True )
ir_pass . apply ( cpp_graph )
if self . _debug :
graph . draw ( ' . ' , ' qat_fp32_ {} ' . format ( pass_name ) ,
graph . all_op_nodes ( ) )
@ -532,15 +532,46 @@ class FakeQAT2MkldnnINT8PerfPass(object):
ids . append ( op . id ( ) )
return set ( ids )
def _transform_to_quantize_mkldnn ( self , graph , op_node ) :
"""
Transform fake_quantize_xx op to quantize mkldnn op in the graph .
"""
input_var_node = graph . _find_node_by_name ( op_node . inputs ,
op_node . input ( " X " ) [ 0 ] )
output_var_node = graph . _find_node_by_name ( op_node . outputs ,
op_node . output ( " Out " ) [ 0 ] )
scale_in = self . _s8_max / self . _load_param (
self . _scope , op_node . input ( " InScale " ) [ 0 ] ) [ 0 ]
quant_op_node = graph . create_op_node (
op_type = ' quantize ' ,
attrs = {
' data_format ' : ' MKLDNNLAYOUT ' ,
' use_mkldnn ' : 1 ,
' Scale ' : scale_in ,
' is_negative_input ' : 1
} ,
inputs = { ' Input ' : input_var_node } ,
outputs = { ' Output ' : output_var_node } )
graph . link_to ( input_var_node , quant_op_node )
graph . link_to ( quant_op_node , output_var_node )
graph . safe_remove_nodes ( op_node )
return quant_op_node
def _transform_to_mul_mkldnn ( self , graph , op_node , quantize_node ) :
input_name = op_node . input ( " X " ) [ 0 ]
scale_in = quantize_node . op ( ) . attr ( " Scale " )
op_node . set_attr ( " scale_y " , [ 1.0 ] )
op_node . set_attr ( " scale_x " , scale_in )
op_node . set_attr ( " scale_out " , 1.0 )
op_node . set_attr ( " force_fp32_output " , True )
def _quantize_fp32_graph ( self , graph ) :
ir_pass = self . _core . get_pass ( ' cpu_quantize_placement_pass ' )
inference_program = graph . to_program ( )
ir_graph = self . _core . Graph ( inference_program . desc )
cpp_graph = graph . graph
ir_pass . set ( ' quantize_enabled_op_types ' , { ' conv2d ' , ' pool2d ' } )
ir_pass . set ( ' quantize_excluded_op_ids ' ,
self . _find_avg_pooling_ids ( graph ) )
ir_pass . apply ( ir_graph )
graph = IrGraph ( ir_graph , for_test = True )
ir_pass . apply ( cpp_graph )
if self . _debug :
graph . draw ( ' . ' , ' qat_int8_ {} ' . format ( ir_pass . type ( ) ) ,
graph . all_op_nodes ( ) )