MKLDNN conv + elementwise_add fusion: Fix output_data to point to the right tensor, also fix transpiler integration

ce
Michal Gallus 7 years ago committed by Tomasz Patejko
parent efd76614fb
commit f688197182

@ -399,8 +399,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Output and elementwise parameter need to have the "
"same dimension sizes");
output_data = output->mutable_data<T>(ctx.GetPlace());
output->ShareDataWith(*residual_param);
output_data = output->mutable_data<T>(ctx.GetPlace());
} else {
output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());

@ -92,7 +92,8 @@ class InferenceTranspiler(object):
if current_op.type in ['conv2d']:
next_op = self.block.ops[i + 1]
if next_op.type == 'elementwise_add':
self._fuse_conv_eltwise(current_op, next_op)
self._fuse_conv_eltwise(i, current_op, next_op)
self.block._remove_op(i + 1) # Remove old conv
self.block._remove_op(i + 1) # Remove elementwise_add
i = i + 1
self._adjust_input()
@ -444,7 +445,7 @@ class InferenceTranspiler(object):
outputs={"Output": out_var},
attrs=attrs)
def _fuse_conv_eltwise(self, conv_op, eltwise_op):
def _fuse_conv_eltwise(self, index, conv_op, eltwise_op):
'''
fuse the conv op with elementwise_add
@ -454,9 +455,26 @@ class InferenceTranspiler(object):
:type eltwise_op: Operator
'''
conv_op._set_attr("fuse_eltwise", True)
self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
residual_var = self.block.var(eltwise_op.input("X")[0])
out_var = self.block.var(eltwise_op.output("Out")[0])
filter_var = self.block.var(conv_op.input("Filter")[0])
in_var = self.block.var(conv_op.input("Input")[0])
bias_var = self.block.var(conv_op.input("Bias")[0])
conv_op.set_attr("fuse_eltwise", True)
attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}
self.block._insert_op(
index,
type="conv2d",
inputs={
"Input": in_var,
"Filter": filter_var,
"Bias": bias_var,
"ResidualData": residual_var
},
outputs={"Output": out_var},
attrs=attrs)
def _adjust_input(self):
for i in range(len(self.block.ops)):

Loading…
Cancel
Save