|
|
@ -61,6 +61,9 @@ class InferenceTranspiler(object):
|
|
|
|
raise TypeError("scope should be as Scope type or None")
|
|
|
|
raise TypeError("scope should be as Scope type or None")
|
|
|
|
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
|
|
|
|
use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if use_mkldnn:
|
|
|
|
|
|
|
|
self._depthwise_conv_mkldnn(program)
|
|
|
|
|
|
|
|
|
|
|
|
self._fuse_batch_norm(program, place, scope)
|
|
|
|
self._fuse_batch_norm(program, place, scope)
|
|
|
|
if use_mkldnn:
|
|
|
|
if use_mkldnn:
|
|
|
|
self._fuse_conv_bias_mkldnn(program)
|
|
|
|
self._fuse_conv_bias_mkldnn(program)
|
|
|
@ -70,6 +73,31 @@ class InferenceTranspiler(object):
|
|
|
|
program) # ResNet residual block merging
|
|
|
|
program) # ResNet residual block merging
|
|
|
|
self._fuse_bn_relu_mkldnn(program)
|
|
|
|
self._fuse_bn_relu_mkldnn(program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _depthwise_conv_mkldnn(self, program):
    '''
    Transpile the program by replacing depthwise_conv2d with conv2d for
    an MKLDNN program.

    Every operator in block 0 whose type is 'depthwise_conv2d' has the
    type of its desc set to "conv2d"; all other operators are left
    untouched. The operators are modified in place, none are added or
    removed.

    The result is:

        - before:

          - any_other_op->depthwise_conv->any_other_op

        - after:

          - any_other_op->conv->any_other_op

    As a side effect, ``self.block`` is set to block 0 of ``program``.

    :param program: program to transpile
    :type program: Program
    '''
    self.block = program.block(0)

    # Only the op type is rewritten, so the op list keeps its length and
    # can be iterated directly instead of being walked with a manual index.
    for current_op in self.block.ops:
        if current_op.type == 'depthwise_conv2d':
            current_op.desc.set_type("conv2d")

    # TODO(luotao): use clone() method to flush the program.desc in force,
    # since some large program.desc will not be flushed immediately.
    # And a better solution will be considered later.
    # The returned clone is deliberately discarded: clone() is called only
    # for the flush described above, and rebinding the local name would
    # not be visible to the caller anyway.
    program.clone()
|
|
|
|
|
|
|
|
|
|
|
|
def _fuse_conv_eltwise_mkldnn(self, program):
|
|
|
|
def _fuse_conv_eltwise_mkldnn(self, program):
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
Transpile the program fusing elementwise_add into conv for MKLDNN
|
|
|
|
Transpile the program fusing elementwise_add into conv for MKLDNN
|
|
|
|