|
|
|
@ -31,7 +31,7 @@ dtype_to_size = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlFlowGraph(object):
|
|
|
|
|
def __init__(self, Program, ops, forward_num):
|
|
|
|
|
def __init__(self, Program, ops, forward_num, skip_opt):
|
|
|
|
|
self._program = Program
|
|
|
|
|
self._ops = ops
|
|
|
|
|
self._forward_num = forward_num
|
|
|
|
@ -41,6 +41,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
self._defs = defaultdict(set)
|
|
|
|
|
self._live_in = defaultdict(set)
|
|
|
|
|
self._live_out = defaultdict(set)
|
|
|
|
|
self._skip_opt = skip_opt
|
|
|
|
|
|
|
|
|
|
def _add_connections(self, connections):
|
|
|
|
|
for node1, node2 in connections:
|
|
|
|
@ -130,6 +131,10 @@ class ControlFlowGraph(object):
|
|
|
|
|
block_desc, x,
|
|
|
|
|
is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
|
|
|
|
|
return False
|
|
|
|
|
if x in self._skip_opt:
|
|
|
|
|
return False
|
|
|
|
|
if not self._find_var(block_desc, x, is_forward).shape():
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
self._build_graph()
|
|
|
|
@ -140,6 +145,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
if op.type() == "while" or op.type() == "while_grad":
|
|
|
|
|
continue
|
|
|
|
|
block_desc = op.block()
|
|
|
|
|
self.current_block_desc = block_desc
|
|
|
|
|
is_forward = i < self._forward_num
|
|
|
|
|
if self.pool:
|
|
|
|
|
defs_can_optimize = filter(
|
|
|
|
@ -197,28 +203,32 @@ def get_cfgs(input_program):
|
|
|
|
|
block_desc = pdesc.block(0)
|
|
|
|
|
op_size = block_desc.op_size()
|
|
|
|
|
# Get global block ops
|
|
|
|
|
ops_list.append(([block_desc.op(i) for i in range(op_size)], op_size))
|
|
|
|
|
ops_list.append(
|
|
|
|
|
([block_desc.op(i) for i in range(op_size)], op_size, set()))
|
|
|
|
|
|
|
|
|
|
while_sub_block_ids = []
|
|
|
|
|
while_grad_sub_block_ids = []
|
|
|
|
|
while_pair = []
|
|
|
|
|
while_op_output = set()
|
|
|
|
|
while_block_id_pair = []
|
|
|
|
|
|
|
|
|
|
for i in range(op_size):
|
|
|
|
|
op = block_desc.op(i)
|
|
|
|
|
if op.type() == "while":
|
|
|
|
|
while_sub_block_ids.append(op.attr("sub_block").id)
|
|
|
|
|
while_op_output.update(op.output_arg_names())
|
|
|
|
|
elif op.type() == "while_grad":
|
|
|
|
|
while_grad_sub_block_ids.append(op.attr("sub_block").id)
|
|
|
|
|
while_op_output.update(op.output_arg_names())
|
|
|
|
|
|
|
|
|
|
# Find while/while_grad block pair
|
|
|
|
|
for grad_id in while_grad_sub_block_ids:
|
|
|
|
|
parent_id = pdesc.block(grad_id).parent
|
|
|
|
|
if parent_id in while_sub_block_ids:
|
|
|
|
|
while_pair.append((parent_id, grad_id))
|
|
|
|
|
while_block_id_pair.append((parent_id, grad_id))
|
|
|
|
|
while_sub_block_ids.remove(parent_id)
|
|
|
|
|
|
|
|
|
|
# Get while/while_grad block ops
|
|
|
|
|
for parent_id, grad_id in while_pair:
|
|
|
|
|
for parent_id, grad_id in while_block_id_pair:
|
|
|
|
|
while_block_ops = []
|
|
|
|
|
while_block = pdesc.block(parent_id)
|
|
|
|
|
while_block_op_size = while_block.op_size()
|
|
|
|
@ -230,7 +240,7 @@ def get_cfgs(input_program):
|
|
|
|
|
for i in range(while_grad_block_op_size):
|
|
|
|
|
while_block_ops.append(while_grad_block.op(i))
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size))
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
|
|
|
|
|
|
|
|
|
|
# Process rest while block ops
|
|
|
|
|
for parent_id in while_sub_block_ids:
|
|
|
|
@ -242,7 +252,7 @@ def get_cfgs(input_program):
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size))
|
|
|
|
|
|
|
|
|
|
cfgs = [ControlFlowGraph(input_program, i, j) for i, j in ops_list]
|
|
|
|
|
cfgs = [ControlFlowGraph(input_program, i, j, k) for i, j, k in ops_list]
|
|
|
|
|
return cfgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|