@ -321,10 +321,11 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph = self . _gather_scales ( graph )
graph = self . _remove_fake_ops ( graph )
graph = self . _update_pooling_scales ( graph )
graph = self . _dequantize_weights ( graph )
graph = self . _optimize_fp32_graph ( graph )
graph = self . _compute_weight_scales ( graph )
graph = self . _update_conv_relu_scales ( graph )
graph = self . _update_pooling_scales ( graph )
graph = self . _quantize_fp32_graph ( graph )
graph = self . _remove_unused_var_nodes ( graph )
return graph
@ -350,6 +351,8 @@ class FakeQAT2MkldnnINT8PerfPass(object):
use_unsigned_int = False
self . _var_quant_scales [ input_name ] = ( use_unsigned_int ,
lod_tensor )
self . _var_quant_scales [ scale_name . replace ( " .scale " , " " ) ] = (
use_unsigned_int , lod_tensor )
if op . name ( ) in self . _fake_dequantize_types :
input_name = op . input ( " X " ) [ 0 ]
@ -378,13 +381,13 @@ class FakeQAT2MkldnnINT8PerfPass(object):
next_op = op_out . outputs [ 0 ]
if next_op . name ( ) not in self . _mul_ops :
self . _remove_fake_quantize ( graph , op )
else :
quant_op = self . _transform_to_quantize_mkldnn ( graph , op )
self . _transform_to_mul_mkldnn ( graph , next_op , quant_op )
for op in graph . all_op_nodes ( ) :
if op . name ( ) in self . _fake_dequantize_types :
self . _remove_fake_dequantize ( graph , op )
op_in = graph . _find_node_by_name ( op . inputs , op . input ( " X " ) [ 0 ] )
prev_op = op_in . inputs [ 0 ]
if prev_op . name ( ) not in self . _mul_ops :
self . _remove_fake_dequantize ( graph , op )
return graph
def _remove_fake_quantize ( self , graph , op ) :
@ -530,7 +533,7 @@ class FakeQAT2MkldnnINT8PerfPass(object):
if op . name ( ) in self . _pool_ops :
if op . op ( ) . attr ( " pooling_type " ) == " avg " :
ids . append ( op . id ( ) )
return set ( ids )
return set ( ids ) if len ( ids ) else set ( [ - 1 ] )
def _transform_to_quantize_mkldnn ( self , graph , op_node ) :
"""
@ -557,13 +560,16 @@ class FakeQAT2MkldnnINT8PerfPass(object):
graph . safe_remove_nodes ( op_node )
return quant_op_node
def _transform_to_mul_mkldnn(self, graph, op_node, quantize_node):
    """Configure a mul op node using the scale of a preceding quantize node.

    Reads the scale attribute from ``quantize_node`` and writes it, together
    with fixed values for the remaining attributes, onto ``op_node``.

    Args:
        graph: Not used by this method; kept so the signature stays
            compatible with existing callers.
        op_node: The op node whose attributes are set in place.
        quantize_node: The quantize op node that supplies the input scale.

    Returns:
        None. ``op_node`` is modified in place.
    """
    # The original body also fetched the op's first X input name into a
    # local that was never read; that dead lookup has been dropped.
    scale_in = quantize_node.op().attr(" Scale ")
    op_node.set_attr(" scale_y ", [1.0])
    # The X-side scale is taken directly from the quantize node.
    op_node.set_attr(" scale_x ", scale_in)
    op_node.set_attr(" scale_out ", 1.0)
    op_node.set_attr(" force_fp32_output ", True)
def _update_conv_relu_scales(self, graph):
    """Flag recorded output scales of ReLU-fused conv ops as unsigned.

    Walks every op node of ``graph``. For each op whose name is in
    ``self._conv_ops``, the first output variable is looked up in
    ``self._var_quant_scales``. When an entry exists, the op's fused
    activation attribute equals the relu value, and its fused residual
    connection attribute equals False, the entry is rewritten with its
    first element set to True and its tensor left unchanged.

    Entries are ``(use_unsigned_int, tensor)`` pairs, so this switches the
    matching outputs to unsigned quantization -- presumably because a
    ReLU output is never negative (NOTE(review): confirm intent).

    Args:
        graph: The IR graph to scan.

    Returns:
        The same ``graph`` object that was passed in.
    """
    recorded_scales = self._var_quant_scales
    for node in graph.all_op_nodes():
        if node.name() not in self._conv_ops:
            continue
        output_var = node.output(" Output ")[0]
        # The checks below run in the same order as a short-circuiting
        # ``and`` chain: scale recorded, relu fused, no residual fused.
        if output_var not in recorded_scales:
            continue
        if node.op().attr(" fuse_activation ") != ' relu ':
            continue
        if node.op().attr(" fuse_residual_connection ") == False:
            _previous_flag, scale_tensor = recorded_scales[output_var]
            recorded_scales[output_var] = (True, scale_tensor)
    return graph
def _quantize_fp32_graph ( self , graph ) :
ir_pass = self . _core . get_pass ( ' cpu_quantize_placement_pass ' )