|
|
|
@ -92,7 +92,8 @@ class InferenceTranspiler(object):
|
|
|
|
|
if current_op.type in ['conv2d']:
|
|
|
|
|
next_op = self.block.ops[i + 1]
|
|
|
|
|
if next_op.type == 'elementwise_add':
|
|
|
|
|
self._fuse_conv_eltwise(current_op, next_op)
|
|
|
|
|
self._fuse_conv_eltwise(i, current_op, next_op)
|
|
|
|
|
self.block._remove_op(i + 1) # Remove old conv
|
|
|
|
|
self.block._remove_op(i + 1) # Remove elementwise_add
|
|
|
|
|
i = i + 1
|
|
|
|
|
self._adjust_input()
|
|
|
|
@ -444,7 +445,7 @@ class InferenceTranspiler(object):
|
|
|
|
|
outputs={"Output": out_var},
|
|
|
|
|
attrs=attrs)
|
|
|
|
|
|
|
|
|
|
def _fuse_conv_eltwise(self, conv_op, eltwise_op):
|
|
|
|
|
def _fuse_conv_eltwise(self, index, conv_op, eltwise_op):
|
|
|
|
|
'''
|
|
|
|
|
fuse the conv op with elementwise_add
|
|
|
|
|
|
|
|
|
@ -454,9 +455,26 @@ class InferenceTranspiler(object):
|
|
|
|
|
:type eltwise_op: Operator
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
conv_op._set_attr("fuse_eltwise", True)
|
|
|
|
|
self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
|
|
|
|
|
self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
|
|
|
|
|
residual_var = self.block.var(eltwise_op.input("X")[0])
|
|
|
|
|
out_var = self.block.var(eltwise_op.output("Out")[0])
|
|
|
|
|
filter_var = self.block.var(conv_op.input("Filter")[0])
|
|
|
|
|
in_var = self.block.var(conv_op.input("Input")[0])
|
|
|
|
|
bias_var = self.block.var(conv_op.input("Bias")[0])
|
|
|
|
|
|
|
|
|
|
conv_op.set_attr("fuse_eltwise", True)
|
|
|
|
|
attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}
|
|
|
|
|
|
|
|
|
|
self.block._insert_op(
|
|
|
|
|
index,
|
|
|
|
|
type="conv2d",
|
|
|
|
|
inputs={
|
|
|
|
|
"Input": in_var,
|
|
|
|
|
"Filter": filter_var,
|
|
|
|
|
"Bias": bias_var,
|
|
|
|
|
"ResidualData": residual_var
|
|
|
|
|
},
|
|
|
|
|
outputs={"Output": out_var},
|
|
|
|
|
attrs=attrs)
|
|
|
|
|
|
|
|
|
|
def _adjust_input(self):
|
|
|
|
|
for i in range(len(self.block.ops)):
|
|
|
|
|