@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import numpy as np
 from .. import core
 from ..framework import Program
@@ -22,7 +23,10 @@ class InferenceTranspiler:
     '''
     Convert the fluid program to an optimized inference program.
 
-    There are several optimizations, only fuse batch normalization is supported now.
+    There are several optimizations:
+
+      - fuse convolution and batch normalization
+      - fuse batch normalization and relu (MKLDNN only)
 
     Examples:
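For orientation, this is how the transpiler is typically driven. A minimal usage sketch, assuming `program` is a trained inference `Program` built elsewhere and that `InferenceTranspiler` is exported on the `fluid` namespace:

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()

# Clone first: the transpiler mutates the program it is handed.
inference_program = program.clone()  # `program`: assumed trained inference Program

t = fluid.InferenceTranspiler()
t.transpile(inference_program, place)
```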
@@ -54,6 +58,51 @@ class InferenceTranspiler:
         if not isinstance(scope, core.Scope):
             raise TypeError("scope should be a Scope type or None")
         self.fuse_batch_norm(program, place, scope)
+        self.fuse_relu_mkldnn(program)
+
+    def fuse_relu_mkldnn(self, program):
+        '''
+        Transpile the program by fusing the relu activation for MKLDNN programs.
+
+        A relu activation that follows a batch_norm OP can be fused into it by
+        setting the :attr:`fuse_with_relu` attribute on the batch_norm OP.
+
+        The result of the fuse is:
+
+        - before:
+
+          - batch_norm -> relu -> any_other_op
+
+        - after:
+
+          - batch_norm -> any_other_op
+
+        :param program: program to transpile
+        :type program: Program
+        '''
+        use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
+        if not use_mkldnn:
+            return
+
+        self.block = program.block(0)
+
+        i = 0
+        while i < len(self.block.ops) - 1:
+            current_op = self.block.ops[i]
+            if current_op.type in ['batch_norm']:
+                next_op = self.block.ops[i + 1]
+                if next_op.type == 'relu':
+                    # modify the batch_norm OP to include relu
+                    current_op.set_attr("fuse_with_relu", True)
+                    # remove the now-fused relu OP
+                    self.block.remove_op(i + 1)
+            i = i + 1
+
+        self._remove_unused_var()
+        # TODO(luotao): use the clone() method to force-flush the program.desc,
+        # since a large program.desc may not be flushed immediately.
+        # A better solution will be considered later.
+        program = program.clone()
+
     def fuse_batch_norm(self, program, place, scope):
         '''
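To make the scan inside fuse_relu_mkldnn concrete, here is a self-contained toy version of the same batch_norm -> relu rewrite; the `Op` class and the op list are invented stand-ins for `program.block(0).ops`:

```python
class Op:
    """Toy stand-in for a block op (illustration only)."""

    def __init__(self, type):
        self.type = type
        self.attrs = {}


ops = [Op('conv2d'), Op('batch_norm'), Op('relu'), Op('fetch')]

i = 0
while i < len(ops) - 1:
    if ops[i].type == 'batch_norm' and ops[i + 1].type == 'relu':
        ops[i].attrs['fuse_with_relu'] = True  # fold relu into batch_norm
        del ops[i + 1]  # drop the now-redundant relu op
    i += 1

print([op.type for op in ops])  # ['conv2d', 'batch_norm', 'fetch']
```

One quirk worth noting: the pass is gated on `bool(os.getenv("FLAGS_use_mkldnn", False))`, which is truthy for any non-empty string, so even `FLAGS_use_mkldnn=0` enables the fusion.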
@@ -107,7 +156,7 @@ class InferenceTranspiler:
         self.input_map = {}  # store the input names that should be adjusted
 
         i = 0
-        while i < len(self.block.ops):
+        while i < len(self.block.ops) - 2:
             current_op = self.block.ops[i]
             # TODO(luotao1): consider only conv2d now. fc would be dealt with later.
             if current_op.type in ['conv2d']:
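The tightened loop bound in this last hunk is worth spelling out. Judging from the method's pattern matching (only partially shown here), the scan peeks ahead at `ops[i + 1]` and, for a conv2d with bias, `ops[i + 2]`, so stopping two ops early keeps the lookahead in range; the conv2d-with-bias shape (conv2d followed by elementwise_add) is an assumption about how fluid lowers bias, not something visible in this hunk. A minimal sketch of that windowed scan, with illustrative op names:

```python
# Toy windowed scan: stopping two ops early guarantees ops[i + 1]
# and ops[i + 2] always exist when the loop body inspects them.
ops = ['conv2d', 'elementwise_add', 'batch_norm', 'fetch']

i = 0
while i < len(ops) - 2:
    if ops[i] == 'conv2d' and ops[i + 1] == 'elementwise_add' \
            and ops[i + 2] == 'batch_norm':
        print('conv2d-with-bias + batch_norm fuse candidate at index', i)
    i += 1
```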