@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import numpy as np
 from .. import core
 from ..framework import Program
@@ -22,7 +23,10 @@ class InferenceTranspiler:
     '''
     Convert the fluid program to optimized inference program.
 
-    There are several optimizations, only fuse batch normalization is supported now.
+    There are several optimizations:
+
+      - fuse convolution and batch normalization
+      - fuse batch normalization and relu (MKLDNN only)
 
     Examples:
 
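
For reference, both fusions are driven from the transpiler's transpile() entry point shown in the next hunk. A minimal usage sketch, assuming the class is exported as fluid.InferenceTranspiler and that a saved inference model exists under a hypothetical ./model_dir:

import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load a saved inference model (directory name is hypothetical).
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    "./model_dir", exe)

# Rewrite the program in place; fuse_batch_norm and (when the
# FLAGS_use_mkldnn environment variable is set) fuse_relu_mkldnn
# are both applied by transpile().
t = fluid.InferenceTranspiler()
t.transpile(program, place, fluid.global_scope())
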
@@ -54,6 +58,51 @@ class InferenceTranspiler:
         if not isinstance(scope, core.Scope):
             raise TypeError("scope should be as Scope type or None")
         self.fuse_batch_norm(program, place, scope)
+        self.fuse_relu_mkldnn(program)
+
+    def fuse_relu_mkldnn(self, program):
+        '''
+        Transpile the program by fusing relu activations for MKLDNN inference.
+
+        A relu activation that follows a batch norm OP can be fused by adding
+        the :attr:`fuse_with_relu` attribute to the batch norm OP.
+
+        The result of the fusion is:
+
+        - before:
+
+          - batch_norm->relu->any_other_op
+
+        - after:
+
+          - batch_norm->any_other_op
+
+        :param program: program to transpile
+        :type program: Program
+        '''
+        use_mkldnn = os.getenv("FLAGS_use_mkldnn", "0").lower() in ("1", "true")
+        if not use_mkldnn:
+            return
+
+        self.block = program.block(0)
+
+        i = 0
+        while i < len(self.block.ops) - 1:
+            current_op = self.block.ops[i]
+            if current_op.type in ['batch_norm']:
+                next_op = self.block.ops[i + 1]
+                if next_op.type == 'relu':
+                    # modify bnorm OP to include relu
+                    current_op.set_attr("fuse_with_relu", True)
+                    # remove relu OP
+                    self.block.remove_op(i + 1)
+            i = i + 1
+
+        self._remove_unused_var()
+        # TODO(luotao): use clone() method to flush the program.desc in force,
+        # since some large program.desc will not be flushed immediately.
+        # And a better solution will be considered later.
+        program = program.clone()
+
+    def fuse_batch_norm(self, program, place, scope):
+        '''
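
The scan-and-fuse loop above can be illustrated on a plain list, independent of the fluid API; Op below is a stand-in class for this sketch, not the fluid operator type:

class Op(object):
    def __init__(self, op_type):
        self.type = op_type
        self.attrs = {}

ops = [Op('conv2d'), Op('batch_norm'), Op('relu'), Op('pool2d')]

i = 0
while i < len(ops) - 1:
    if ops[i].type == 'batch_norm' and ops[i + 1].type == 'relu':
        ops[i].attrs['fuse_with_relu'] = True  # fold relu into batch_norm
        del ops[i + 1]                         # drop the standalone relu OP
    i += 1

# batch_norm->relu->pool2d has become batch_norm->pool2d
assert [op.type for op in ops] == ['conv2d', 'batch_norm', 'pool2d']

The real method additionally calls _remove_unused_var() and clones the program, per the TODO above, to force the underlying program.desc to be flushed.
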
@@ -107,7 +156,7 @@ class InferenceTranspiler:
         self.input_map = {}  # store the input names should be adjusted
 
         i = 0
-        while i < len(self.block.ops):
+        while i < len(self.block.ops) - 2:
             current_op = self.block.ops[i]
             # TODO(luotao1): consider only conv2d now. fc would be delt later.
             if current_op.type in ['conv2d']:
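
The tightened bound suggests the body of this loop looks up to two ops ahead of current_op (ops[i + 1] and ops[i + 2]), which would index past the end of the list under the old bound. A minimal sketch of the safety argument; the three-op conv2d->elementwise_add->batch_norm pattern is an assumption for illustration, since only the conv2d check is visible in this hunk:

ops = ['conv2d', 'elementwise_add', 'batch_norm', 'fetch']

i = 0
while i < len(ops) - 2:
    # ops[i + 1] and ops[i + 2] always exist inside this loop body.
    window = (ops[i], ops[i + 1], ops[i + 2])
    if window == ('conv2d', 'elementwise_add', 'batch_norm'):
        print('fusable pattern starting at index', i)
    i += 1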