@ -102,25 +102,26 @@ class TestConv2dInt8Op(TestConv2dOp):
output1 = conv2d_forward_refer (
output1 = conv2d_forward_refer (
input . astype ( np . int32 ) , filter_int , self . groups ,
input . astype ( np . int32 ) , filter_int , self . groups ,
conv2d_param ) . astype ( np . float32 )
conv2d_param ) . astype ( np . float32 )
output1_tmp = np . round ( output1 * (
self . scale_out / ( self . scale_in * self . scale_weights [ 0 ] ) ) )
if self . fuse_residual :
if self . fuse_residual :
input_residual = np . random . randint (
input_residual = np . random . randint (
0 , 10 , self . input_residual_size ) . astype ( self . srctype )
0 , 10 , self . input_residual_size ) . astype ( self . srctype )
output_tmp = np . round ( output1 * ( self . scale_out / (
output_tmp _res = np . round ( output1 * ( self . scale_out / (
self . scale_in * self . scale_weights [ 0 ] ) ) + format_reorder (
self . scale_in * self . scale_weights [ 0 ] ) ) + format_reorder (
input_residual , self . input_residual_size ) . astype (
input_residual , self . input_residual_size ) . astype (
np . int32 ) * ( self . scale_out / self . scale_in_eltwise
np . int32 ) * ( self . scale_out / self . scale_in_eltwise
) )
) )
output_tmp2 = np . round ( output1 * (
self . scale_out / ( self . scale_in * self . scale_weights [ 0 ] ) ) )
if self . fuse_relu :
if self . fuse_relu :
output = np . maximum ( output_tmp , 0 ) . astype ( self . dsttype )
output = np . maximum ( output_tmp_res , 0 ) . astype ( self . dsttype )
else :
else :
output = output_tmp . astype ( self . dsttype )
output = output_tmp _res . astype ( self . dsttype )
else :
else :
if self . fuse_relu :
if self . fuse_relu :
output = np . maximum ( output _tmp2 , 0 ) . astype ( self . dsttype )
output = np . maximum ( output 1 _tmp, 0 ) . astype ( self . dsttype )
else :
else :
output = output _tmp2 . astype ( self . dsttype )
output = output 1 _tmp. astype ( self . dsttype )
self . inputs = {
self . inputs = {
' Input ' :
' Input ' :
@ -265,11 +266,9 @@ def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual):
self . srctype = input_dt
self . srctype = input_dt
self . dsttype = np . uint8 if fuse_relu else np . int8
self . dsttype = np . uint8 if fuse_relu else np . int8
def init_fuse_relu ( self ) :
self . fuse_relu = fuse_relu
self . fuse_relu = fuse_relu
def init_fuse_residual ( self ) :
self . fuse_residual = fuse_residual
self . fuse_residual = fuse_residual
def create_test_int8_class ( parent ) :
def create_test_int8_class ( parent ) :