|
|
|
@ -145,7 +145,6 @@ class ControlFlowGraph(object):
|
|
|
|
|
if op.type() == "while" or op.type() == "while_grad":
|
|
|
|
|
continue
|
|
|
|
|
block_desc = op.block()
|
|
|
|
|
self.current_block_desc = block_desc
|
|
|
|
|
is_forward = i < self._forward_num
|
|
|
|
|
if self.pool:
|
|
|
|
|
defs_can_optimize = filter(
|
|
|
|
@ -208,17 +207,17 @@ def get_cfgs(input_program):
|
|
|
|
|
|
|
|
|
|
while_sub_block_ids = []
|
|
|
|
|
while_grad_sub_block_ids = []
|
|
|
|
|
while_op_output = set()
|
|
|
|
|
while_block_id_pair = []
|
|
|
|
|
while_op_dict = {}
|
|
|
|
|
|
|
|
|
|
for i in range(op_size):
|
|
|
|
|
op = block_desc.op(i)
|
|
|
|
|
if op.type() == "while":
|
|
|
|
|
while_sub_block_ids.append(op.attr("sub_block").id)
|
|
|
|
|
while_op_output.update(op.output_arg_names())
|
|
|
|
|
while_op_dict[op.attr("sub_block").id] = op
|
|
|
|
|
elif op.type() == "while_grad":
|
|
|
|
|
while_grad_sub_block_ids.append(op.attr("sub_block").id)
|
|
|
|
|
while_op_output.update(op.output_arg_names())
|
|
|
|
|
while_op_dict[op.attr("sub_block").id] = op
|
|
|
|
|
|
|
|
|
|
# Find while/while_grad block pair
|
|
|
|
|
for grad_id in while_grad_sub_block_ids:
|
|
|
|
@ -240,6 +239,10 @@ def get_cfgs(input_program):
|
|
|
|
|
for i in range(while_grad_block_op_size):
|
|
|
|
|
while_block_ops.append(while_grad_block.op(i))
|
|
|
|
|
|
|
|
|
|
while_op_output = set()
|
|
|
|
|
while_op_output.update(while_op_dict[parent_id].output_arg_names())
|
|
|
|
|
while_op_output.update(while_op_dict[grad_id].output_arg_names())
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
|
|
|
|
|
|
|
|
|
|
# Process rest while block ops
|
|
|
|
@ -250,9 +253,15 @@ def get_cfgs(input_program):
|
|
|
|
|
for i in range(while_block_op_size):
|
|
|
|
|
while_block_ops.append(while_block.op(i))
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size))
|
|
|
|
|
while_op_output = set()
|
|
|
|
|
while_op_output.update(while_op_dict[parent_id].output_arg_names())
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
|
|
|
|
|
|
|
|
|
|
cfgs = [ControlFlowGraph(input_program, i, j, k) for i, j, k in ops_list]
|
|
|
|
|
cfgs = [
|
|
|
|
|
ControlFlowGraph(input_program, ops, forward_num, skip_opt)
|
|
|
|
|
for ops, forward_num, skip_opt in ops_list
|
|
|
|
|
]
|
|
|
|
|
return cfgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|