@@ -45,10 +45,11 @@ class InferenceTranspiler:
           - conv->elementwise_add->any_other_op
 
         The transpile stages are:
-        1. insert elementwise_add op when bias == 0, and adjust its input and output.
+        1. insert elementwise_add op when bias == 0.
         2. fuse the batch_norm's parameters to conv and elementwise_add operators.
-        3. remove batch_norm ops and its variables which are not used in any other ops.
-        4. remove unused variables.
+        3. remove batch_norm ops which are not used in any other ops.
+        4. adjust the input of any_other_op to be the output of elementwise_add operator.
+        5. remove unused variables.
 
         :param program: program to transpile
         :type program: Program
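The reworked stage list above describes the fold exactly: the batch_norm statistics are absorbed into the conv weights plus a single elementwise_add bias. A minimal numpy sketch of that algebra (illustrative only; the conv is collapsed to a matmul and every name and value here is invented for the example):

```python
import numpy as np

# Toy check of the fold the stages describe; the conv is collapsed to a
# matmul and all values are random stand-ins.
rng = np.random.RandomState(0)
x = rng.randn(4, 3).astype(np.float32)        # input
w = rng.randn(3, 2).astype(np.float32)        # conv weight
bias = rng.randn(2).astype(np.float32)        # conv bias (zeros if absent)
mean = rng.randn(2).astype(np.float32)        # batch_norm mean
std = (rng.rand(2) + 1.0).astype(np.float32)  # batch_norm std (with epsilon)
scale = rng.randn(2).astype(np.float32)       # batch_norm scale (gamma)
shift = rng.randn(2).astype(np.float32)       # batch_norm shift (beta)

# Original graph: conv -> (elementwise_add) -> batch_norm.
y_ref = (x.dot(w) + bias - mean) / std * scale + shift

# Fused graph: conv with rescaled weights -> elementwise_add with new bias.
tmp = scale / std
w_fused = w * tmp                             # per-output-channel rescale
bias_fused = (bias - mean) * tmp + shift
y_fused = x.dot(w_fused) + bias_fused

assert np.allclose(y_ref, y_fused, atol=1e-5)
```

The same tmp = scale / std factor is exactly what `_fuse_param` computes further down in this patch.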
@@ -62,24 +63,35 @@
         self.scope = scope
         self.place = place
         self.block = program.block(0)
+        self.input_map = {}  # store the input names should be adjusted
 
         i = 0
         while i < len(self.block.ops):
             current_op = self.block.ops[i]
             # TODO(luotao1): consider only conv2d now. fc would be dealt with later.
             if current_op.type in ['conv2d']:
                 next_op = self.block.ops[i + 1]
-                # TODO(luotao1): consider only conv2d without bias now.
-                # If conv2d with bias, the next_op.type is elementwise_add.
+                # conv2d without bias
                 if (next_op.type == 'batch_norm'):
                     # insert bias op
                     bias_op = self._insert_bias_op(i + 1, current_op, next_op)
                     # fuse batch_norm
-                    self._fuse_param(current_op, next_op, bias_op)
+                    self._fuse_param(current_op, next_op, bias_op, 0)
                     # remove batch_norm_op
                     self.block.remove_op(i + 2)
                     i = i + 1
+                # conv2d with bias, the next_op.type is elementwise_add
+                elif (next_op.type == 'elementwise_add'):
+                    next_next_op = self.block.ops[i + 2]
+                    if (next_next_op.type == 'batch_norm'):
+                        # fuse batch_norm
+                        self._fuse_param(current_op, next_next_op, next_op, 1)
+                        # remove batch_norm_op
+                        self.block.remove_op(i + 2)
+                        i = i + 1
             i = i + 1
 
+        self._adjust_input()
         self._remove_unused_var()
         return program
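The new elif branch extends the same linear scan to the conv2d -> elementwise_add -> batch_norm chain, i.e. conv with bias; the pass still assumes the ops of block 0 appear adjacently in that order (the remaining TODO defers fc). A hedged usage sketch, assuming the fluid API of that period (fluid.InferenceTranspiler, fluid.io.load_inference_model); the model path and variable names are placeholders:

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load a saved inference model; "./model_dir" is an example path.
[inference_program, feed_names,
 fetch_targets] = fluid.io.load_inference_model("./model_dir", exe)

# Fuse conv2d (+ elementwise_add) + batch_norm before running inference.
t = fluid.InferenceTranspiler()
inference_program = t.transpile(inference_program, place)
```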
@@ -113,7 +125,7 @@
             attrs={"axis": 1})  # dim_start=1
         return bias_op
 
-    def _fuse_param(self, current_op, bn_op, bias_op):
+    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
         '''
         fuse the batch_norm_op's parameters to current_op (conv or fc)
 
@@ -123,6 +135,8 @@
         :type bn_op: Operator
         :param bias_op: elementwise_add operator for adding bias
         :type bias_op: Operator
+        :param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
+        :type with_bias: Int
         '''
 
         def _load_tensor(param_name):
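The hunk is cut off before the bodies of `_load_tensor` and the `_load_param` used below; presumably they reduce to scope lookups. A standalone sketch of that round-trip, with an invented parameter name, showing how a transpiler can read a parameter as numpy and write fused values back in place:

```python
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
scope = fluid.global_scope()

# Create a toy "parameter" tensor in the scope; the name is invented.
tensor = scope.var("conv2d_0.w_0").get_tensor()
tensor.set(np.ones((2, 3), dtype=np.float32), place)

# Read it out as numpy (what a _load_param-style helper would return) ...
weights = np.array(scope.find_var("conv2d_0.w_0").get_tensor())

# ... scale it, and write the updated values back into the same tensor.
tensor.set(np.float32(weights * 0.5), place)
```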
@@ -144,7 +158,10 @@
         tmp = np.float32(np.divide(scale_bn, std_bn))
 
         # add bias of batch_norm_op to conv2d
-        bias = np.zeros(bias_bn.shape)
+        if with_bias:
+            bias = _load_param(bias_op.input("Y"))
+        else:
+            bias = np.zeros(bias_bn.shape)
         bias = np.float32(
             np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
         bias_tensor = _load_tensor(bias_op.input("Y"))
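Both branches feed the same formula that follows them; with_bias only decides whether the starting point is the conv's real bias (loaded from the elementwise_add's Y input) or a zero vector. Toy numbers, unrelated to any real model, make the arithmetic of these lines concrete:

```python
import numpy as np

scale_bn = np.float32([0.5, 2.0])
std_bn = np.float32([1.0, 4.0])      # already includes epsilon
mean_bn = np.float32([0.1, -0.2])
bias_bn = np.float32([0.3, 0.7])
tmp = np.float32(np.divide(scale_bn, std_bn))

conv_bias = np.float32([1.0, -1.0])  # with_bias == 1: existing conv bias
zero_bias = np.zeros(bias_bn.shape)  # with_bias == 0: no conv bias yet

for bias in (conv_bias, zero_bias):
    fused = np.float32(
        np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn))
    print(fused)  # the value written into the elementwise_add's Y tensor
```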
@@ -159,6 +176,17 @@
         # set the updated parameters
         current_tensor.set(np.array(dst_param), self.place)
 
+        # collect the renamed input
+        self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]
+
+    def _adjust_input(self):
+        for i in range(len(self.block.ops)):
+            current_op = self.block.ops[i]
+            for input_arg in current_op.input_arg_names:
+                if input_arg in self.input_map:
+                    current_op.rename_input(input_arg,
+                                            self.input_map[input_arg])
+
     def _remove_unused_var(self):
         '''
         remove unused variables in program
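Removing the batch_norm op leaves downstream ops still reading its output variable, so `_fuse_param` records a rename (bn output -> elementwise_add output) in input_map and `_adjust_input` rewires every consumer. A toy, pure-Python model of that rewiring; the op list and variable names are invented:

```python
# Stand-in op list: plain dicts instead of fluid operators.
input_map = {"batch_norm_0.tmp_out": "elementwise_add_0.out"}

ops = [
    {"type": "conv2d", "inputs": ["image"], "outputs": ["conv2d_0.out"]},
    {"type": "elementwise_add", "inputs": ["conv2d_0.out", "conv2d_0.b_0"],
     "outputs": ["elementwise_add_0.out"]},
    {"type": "relu", "inputs": ["batch_norm_0.tmp_out"],
     "outputs": ["relu_0.out"]},
]

# What _adjust_input does: rename every input that appears in input_map.
for op in ops:
    op["inputs"] = [input_map.get(name, name) for name in op["inputs"]]

assert ops[2]["inputs"] == ["elementwise_add_0.out"]
```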