MKLDNN conv + elementwise_add fusion: Fix transpiler integration to predict skip connection input of eltwise_add

ce
Michal Gallus 7 years ago committed by Tomasz Patejko
parent fb7a50b230
commit f0efc244c6

@ -455,11 +455,15 @@ class InferenceTranspiler(object):
:type eltwise_op: Operator :type eltwise_op: Operator
''' '''
residual_var = self.block.var(eltwise_op.input("X")[0]) eltwise_input = "X"
out_var = self.block.var(eltwise_op.output("Out")[0]) if eltwise_op.input("X")[0] == conv_op.output("Output")[0]:
filter_var = self.block.var(conv_op.input("Filter")[0]) eltwise_input = "Y"
in_var = self.block.var(conv_op.input("Input")[0])
bias_var = self.block.var(conv_op.input("Bias")[0]) residual_var = self.block.vars[eltwise_op.input(eltwise_input)[0]]
out_var = self.block.vars[eltwise_op.output("Out")[0]]
filter_var = self.block.vars[conv_op.input("Filter")[0]]
in_var = self.block.vars[conv_op.input("Input")[0]]
bias_var = self.block.vars[conv_op.input("Bias")[0]]
conv_op.set_attr("fuse_eltwise", True) conv_op.set_attr("fuse_eltwise", True)
attrs = {name: conv_op.attr(name) for name in conv_op.attr_names} attrs = {name: conv_op.attr(name) for name in conv_op.attr_names}

Loading…
Cancel
Save