|
|
|
|
@ -61,30 +61,26 @@ class InferenceTranspiler:
|
|
|
|
|
'''
|
|
|
|
|
self.scope = scope
|
|
|
|
|
self.place = place
|
|
|
|
|
self.block_desc = program.get_desc().block(0)
|
|
|
|
|
self.block = program.block(0)
|
|
|
|
|
i = 0
|
|
|
|
|
while i < self.block_desc.op_size():
|
|
|
|
|
current_op = self.block_desc.op(i)
|
|
|
|
|
while i < len(self.block.ops):
|
|
|
|
|
current_op = self.block.ops[i]
|
|
|
|
|
# TODO(luotao1): consider only conv2d now. fc would be delt later.
|
|
|
|
|
if current_op.type() in ['conv2d']:
|
|
|
|
|
next_op = self.block_desc.op(i + 1)
|
|
|
|
|
if current_op.type in ['conv2d']:
|
|
|
|
|
next_op = self.block.ops[i + 1]
|
|
|
|
|
# TODO(luotao1): consider only conv2d without bias now.
|
|
|
|
|
# If conv2d with bias, the next_op.type is elementwise_add.
|
|
|
|
|
if (next_op.type() == 'batch_norm'):
|
|
|
|
|
if (next_op.type == 'batch_norm'):
|
|
|
|
|
# insert bias op
|
|
|
|
|
bias_op = self._insert_bias_op(i + 1, current_op, next_op)
|
|
|
|
|
program.sync_with_cpp()
|
|
|
|
|
# fuse batch_norm
|
|
|
|
|
self._fuse_param(current_op, next_op, bias_op)
|
|
|
|
|
# remove batch_norm_op
|
|
|
|
|
self.block_desc.remove_op(i + 2, i + 3)
|
|
|
|
|
program.sync_with_cpp()
|
|
|
|
|
self.block.remove_op(i + 2)
|
|
|
|
|
i = i + 1
|
|
|
|
|
i = i + 1
|
|
|
|
|
|
|
|
|
|
self._remove_unused_var()
|
|
|
|
|
program.sync_with_cpp()
|
|
|
|
|
|
|
|
|
|
return program
|
|
|
|
|
|
|
|
|
|
# ====================== private transpiler functions =====================
|
|
|
|
|
@ -102,14 +98,19 @@ class InferenceTranspiler:
|
|
|
|
|
:return: bias_op
|
|
|
|
|
:rtype: Operator
|
|
|
|
|
'''
|
|
|
|
|
bias_op = self.block_desc.insert_op(index)
|
|
|
|
|
bias_op.set_type("elementwise_add")
|
|
|
|
|
# The input of bias_op is current_op's output and Bias of bn_op
|
|
|
|
|
# The output of bias_op is bn_op's output
|
|
|
|
|
bias_op.set_input("X", current_op.output("Output"))
|
|
|
|
|
bias_op.set_input("Y", bn_op.input("Bias"))
|
|
|
|
|
bias_op.set_output("Out", bn_op.output("Y"))
|
|
|
|
|
bias_op.set_attr('axis', 1) # dim_start=1
|
|
|
|
|
x_var = self.block.var(current_op.output("Output")[0])
|
|
|
|
|
y_var = self.block.var(bn_op.input("Bias")[0])
|
|
|
|
|
out_var = self.block.var(bn_op.output("Y")[0])
|
|
|
|
|
|
|
|
|
|
bias_op = self.block.insert_op(
|
|
|
|
|
index,
|
|
|
|
|
type="elementwise_add",
|
|
|
|
|
inputs={"X": x_var,
|
|
|
|
|
"Y": y_var},
|
|
|
|
|
outputs={"Out": out_var},
|
|
|
|
|
attrs={"axis": 1}) # dim_start=1
|
|
|
|
|
return bias_op
|
|
|
|
|
|
|
|
|
|
def _fuse_param(self, current_op, bn_op, bias_op):
|
|
|
|
|
@ -160,15 +161,15 @@ class InferenceTranspiler:
|
|
|
|
|
|
|
|
|
|
def _remove_unused_var(self):
|
|
|
|
|
'''
|
|
|
|
|
remove unused varibles in program desc
|
|
|
|
|
remove unused varibles in program
|
|
|
|
|
'''
|
|
|
|
|
args = []
|
|
|
|
|
for i in xrange(0, self.block_desc.op_size()):
|
|
|
|
|
current_op = self.block_desc.op(i)
|
|
|
|
|
args += current_op.input_arg_names()
|
|
|
|
|
args += current_op.output_arg_names()
|
|
|
|
|
for i in range(len(self.block.ops)):
|
|
|
|
|
current_op = self.block.ops[i]
|
|
|
|
|
args += current_op.input_arg_names
|
|
|
|
|
args += current_op.output_arg_names
|
|
|
|
|
args = list(set(args)) # unique the input and output arguments
|
|
|
|
|
|
|
|
|
|
for var in self.block_desc.all_vars():
|
|
|
|
|
if var.name() not in args:
|
|
|
|
|
self.block_desc.remove_var(var.name())
|
|
|
|
|
for var in self.block.vars.keys():
|
|
|
|
|
if var not in args:
|
|
|
|
|
self.block.remove_var(var)
|
|
|
|
|
|