@ -43,7 +43,7 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
self . conv_output = np . ndarray ( self . conv_output_size ) . astype ( self . dtype )
self . conv_output2 = np . ndarray ( self . conv_output2_size ) . astype (
self . dtype )
self . quantized_ops = ' conv2d '
self . quantized_ops = ' conv2d ,mul '
self . variables = {
" input " : self . input ,
" filter " : self . filter ,
@ -51,6 +51,22 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
" conv_output " : self . conv_output ,
" conv_output2 " : self . conv_output2 ,
}
self . mul_input_size = [ 1 , 3 ]
self . mul_weights_size = [ 3 , 5 ]
self . mul_output_size = [ 1 , 5 ]
self . mul_input = np . random . random ( self . mul_input_size ) . astype (
self . dtype )
self . mul_weights = np . ones ( self . mul_weights_size , self . dtype )
self . mul_weights_bad = np . ones ( [ 1 , 1 ] , self . dtype )
self . mul_output = np . ndarray ( self . mul_output_size ) . astype ( self . dtype )
self . mul_output_scale = np . linspace ( 1 , 5 , num = 5 ) . astype ( self . dtype )
self . variables_mul = {
" mul_input " : self . mul_input ,
" mul_weights " : self . mul_weights ,
" mul_output " : self . mul_output ,
" mul_weights_bad " : self . mul_weights_bad
}
def prepare_program ( self , program ) :
block = program . global_block ( )
@ -92,6 +108,23 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
' fuse_brelu ' : True
} )
def prepare_program_mul ( self , program ) :
block = program . global_block ( )
for name in self . variables_mul :
block . create_var (
name = name ,
dtype = " float32 " ,
shape = self . variables_mul [ name ] . shape )
mul_op1 = block . append_op (
type = " mul " ,
inputs = {
" X " : block . var ( ' mul_input ' ) ,
" Y " : block . var ( ' mul_weights ' )
} ,
outputs = { " Out " : block . var ( ' mul_output ' ) } ,
attrs = { ' use_mkldnn ' : self . use_mkldnn } )
def remove_fuse_activation_attribute ( self , graph ) :
for op in graph . all_op_nodes ( ) :
op . op ( ) . remove_attr ( " fuse_activation " )
@ -103,11 +136,13 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
def check_graph_after_pass ( self , graph ) :
for op in graph . all_op_nodes ( ) :
self . assertTrue ( op . op ( ) . has_attr ( " fuse_activation " ) )
if op . op ( ) . has_attr ( " fuse_relu " ) and op . op ( ) . attr ( " fuse_relu " ) :
self . assertTrue ( op . op ( ) . attr ( " fuse_activation " ) == " relu " )
if op . op ( ) . has_attr ( " fuse_brelu " ) and op . op ( ) . attr ( " fuse_brelu " ) :
self . assertTrue ( op . op ( ) . attr ( " fuse_activation " ) == " relu6 " )
if op . op ( ) . type ( ) == " conv2d " :
self . assertTrue ( op . op ( ) . has_attr ( " fuse_activation " ) )
if op . op ( ) . has_attr ( " fuse_relu " ) and op . op ( ) . attr ( " fuse_relu " ) :
self . assertTrue ( op . op ( ) . attr ( " fuse_activation " ) == " relu " )
if op . op ( ) . has_attr ( " fuse_brelu " ) and op . op ( ) . attr (
" fuse_brelu " ) :
self . assertTrue ( op . op ( ) . attr ( " fuse_activation " ) == " relu6 " )
def test_quant_update_activation ( self ) :
program = fluid . Program ( )
@ -125,6 +160,39 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
graph = quant2_int8_mkldnn_pass . _update_activations ( graph )
self . check_graph_after_pass ( graph )
def test_dequantize_op_weights ( self ) :
program = fluid . Program ( )
with fluid . program_guard ( program ) :
self . prepare_program_mul ( program )
graph = IrGraph ( core . Graph ( program . desc ) , for_test = True )
for op in graph . all_op_nodes ( ) :
if op . op ( ) . type ( ) == " mul " :
op_node = op
break
qpass = Quant2Int8MkldnnPass (
self . quantized_ops ,
_scope = self . scope ,
_place = self . place ,
_core = core ,
_debug = False )
qpass . _weight_scales [ " mul_output " ] = self . mul_output_scale
param = self . scope . var ( " mul_weights " ) . get_tensor ( )
param . set ( self . variables_mul [ " mul_weights " ] , self . place )
qpass . _dequantize_op_weights ( graph , op_node , " Y " , " Out " )
assert np . allclose (
self . scope . find_var ( " mul_weights " ) . get_tensor ( ) ,
[ [ 127 , 63.5 , 42.3333 , 31.75 , 25.4 ] ,
[ 127 , 63.5 , 42.3333 , 31.75 , 25.4 ] ,
[ 127 , 63.5 , 42.3333 , 31.75 , 25.4 ] ] )
param = self . scope . var ( " mul_weights " ) . get_tensor ( )
param . set ( self . variables_mul [ " mul_weights_bad " ] , self . place )
with self . assertRaises ( ValueError ) :
qpass . _dequantize_op_weights ( graph , op_node , " Y " , " Out " )
if __name__ == ' __main__ ' :
unittest . main ( )