|
|
|
@ -29,6 +29,8 @@ dtype_to_size = {
|
|
|
|
|
core.VarDesc.VarType.BOOL: 1
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Operator types that carry a "sub_block" attribute (a forward op and its
# matching *_grad op). ControlFlowGraph skips ops whose type is listed here
# when it walks the op list; the contents of their sub-blocks are gathered
# separately by pairing each forward op with its grad op.
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ControlFlowGraph(object):
|
|
|
|
|
def __init__(self, Program, ops, forward_num, skip_opt):
|
|
|
|
@ -141,7 +143,7 @@ class ControlFlowGraph(object):
|
|
|
|
|
self.pool = []
|
|
|
|
|
for i in range(self.op_size):
|
|
|
|
|
op = self._ops[i]
|
|
|
|
|
if op.type() == "while" or op.type() == "while_grad":
|
|
|
|
|
if op.type() in sub_block_ops:
|
|
|
|
|
continue
|
|
|
|
|
block_desc = op.block()
|
|
|
|
|
is_forward = i < self._forward_num
|
|
|
|
@ -198,67 +200,75 @@ class ControlFlowGraph(object):
|
|
|
|
|
block_desc, var_name, is_forward).shape()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_cfgs(input_program):
|
|
|
|
|
def _process_sub_block_pair(pdesc, sub_block_pair):
|
|
|
|
|
ops_list = []
|
|
|
|
|
pdesc = input_program.get_desc()
|
|
|
|
|
block_desc = pdesc.block(0)
|
|
|
|
|
op_size = block_desc.op_size()
|
|
|
|
|
# Get global block ops
|
|
|
|
|
ops_list.append(
|
|
|
|
|
([block_desc.op(i) for i in range(op_size)], op_size, set()))
|
|
|
|
|
|
|
|
|
|
while_sub_block_ids = []
|
|
|
|
|
while_grad_sub_block_ids = []
|
|
|
|
|
while_block_id_pair = []
|
|
|
|
|
while_op_dict = {}
|
|
|
|
|
for fwd_op, bwd_op in sub_block_pair:
|
|
|
|
|
sub_block_ids = []
|
|
|
|
|
grad_sub_block_ids = []
|
|
|
|
|
sub_block_id_pair = []
|
|
|
|
|
sub_op_dict = {}
|
|
|
|
|
for i in range(op_size):
|
|
|
|
|
op = block_desc.op(i)
|
|
|
|
|
if op.type() == fwd_op:
|
|
|
|
|
sub_block_ids.append(op.attr("sub_block").id)
|
|
|
|
|
sub_op_dict[op.attr("sub_block").id] = op
|
|
|
|
|
elif op.type() == bwd_op:
|
|
|
|
|
grad_sub_block_ids.append(op.attr("sub_block").id)
|
|
|
|
|
sub_op_dict[op.attr("sub_block").id] = op
|
|
|
|
|
|
|
|
|
|
for i in range(op_size):
|
|
|
|
|
op = block_desc.op(i)
|
|
|
|
|
if op.type() == "while":
|
|
|
|
|
while_sub_block_ids.append(op.attr("sub_block").id)
|
|
|
|
|
while_op_dict[op.attr("sub_block").id] = op
|
|
|
|
|
elif op.type() == "while_grad":
|
|
|
|
|
while_grad_sub_block_ids.append(op.attr("sub_block").id)
|
|
|
|
|
while_op_dict[op.attr("sub_block").id] = op
|
|
|
|
|
# Find fwd_op/bwd_op block pair
|
|
|
|
|
for grad_id in grad_sub_block_ids:
|
|
|
|
|
parent_id = pdesc.block(grad_id).parent
|
|
|
|
|
if parent_id in sub_block_ids:
|
|
|
|
|
sub_block_id_pair.append((parent_id, grad_id))
|
|
|
|
|
sub_block_ids.remove(parent_id)
|
|
|
|
|
|
|
|
|
|
# Find while/while_grad block pair
|
|
|
|
|
for grad_id in while_grad_sub_block_ids:
|
|
|
|
|
parent_id = pdesc.block(grad_id).parent
|
|
|
|
|
if parent_id in while_sub_block_ids:
|
|
|
|
|
while_block_id_pair.append((parent_id, grad_id))
|
|
|
|
|
while_sub_block_ids.remove(parent_id)
|
|
|
|
|
# Get fwd_op/bwd_op block ops
|
|
|
|
|
for parent_id, grad_id in sub_block_id_pair:
|
|
|
|
|
sub_block_ops = []
|
|
|
|
|
sub_block = pdesc.block(parent_id)
|
|
|
|
|
block_op_size = sub_block.op_size()
|
|
|
|
|
for i in range(block_op_size):
|
|
|
|
|
sub_block_ops.append(sub_block.op(i))
|
|
|
|
|
|
|
|
|
|
# Get while/while_grad block ops
|
|
|
|
|
for parent_id, grad_id in while_block_id_pair:
|
|
|
|
|
while_block_ops = []
|
|
|
|
|
while_block = pdesc.block(parent_id)
|
|
|
|
|
while_block_op_size = while_block.op_size()
|
|
|
|
|
for i in range(while_block_op_size):
|
|
|
|
|
while_block_ops.append(while_block.op(i))
|
|
|
|
|
grad_sub_block = pdesc.block(grad_id)
|
|
|
|
|
grad_sub_block_op_size = grad_sub_block.op_size()
|
|
|
|
|
for i in range(grad_sub_block_op_size):
|
|
|
|
|
sub_block_ops.append(grad_sub_block.op(i))
|
|
|
|
|
|
|
|
|
|
while_grad_block = pdesc.block(grad_id)
|
|
|
|
|
while_grad_block_op_size = while_grad_block.op_size()
|
|
|
|
|
for i in range(while_grad_block_op_size):
|
|
|
|
|
while_block_ops.append(while_grad_block.op(i))
|
|
|
|
|
sub_op_output = set()
|
|
|
|
|
sub_op_output.update(sub_op_dict[parent_id].output_arg_names())
|
|
|
|
|
sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
|
|
|
|
|
ops_list.append((sub_block_ops, block_op_size, sub_op_output))
|
|
|
|
|
|
|
|
|
|
while_op_output = set()
|
|
|
|
|
while_op_output.update(while_op_dict[parent_id].output_arg_names())
|
|
|
|
|
while_op_output.update(while_op_dict[grad_id].output_arg_names())
|
|
|
|
|
# Process rest fwd_op block ops
|
|
|
|
|
for parent_id in sub_block_ids:
|
|
|
|
|
sub_block_ops = []
|
|
|
|
|
sub_block = pdesc.block(parent_id)
|
|
|
|
|
sub_block_op_size = sub_block.op_size()
|
|
|
|
|
for i in range(sub_block_op_size):
|
|
|
|
|
sub_block_ops.append(sub_block.op(i))
|
|
|
|
|
sub_op_output = set()
|
|
|
|
|
sub_op_output.update(sub_op_dict[parent_id].output_arg_names())
|
|
|
|
|
ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
|
|
|
|
|
return ops_list
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
|
|
|
|
|
|
|
|
|
|
# Process rest while block ops
|
|
|
|
|
for parent_id in while_sub_block_ids:
|
|
|
|
|
while_block_ops = []
|
|
|
|
|
while_block = pdesc.block(parent_id)
|
|
|
|
|
while_block_op_size = while_block.op_size()
|
|
|
|
|
for i in range(while_block_op_size):
|
|
|
|
|
while_block_ops.append(while_block.op(i))
|
|
|
|
|
def _get_cfgs(input_program):
|
|
|
|
|
ops_list = []
|
|
|
|
|
pdesc = input_program.get_desc()
|
|
|
|
|
block_desc = pdesc.block(0)
|
|
|
|
|
op_size = block_desc.op_size()
|
|
|
|
|
# Get global block ops
|
|
|
|
|
ops_list.append(
|
|
|
|
|
([block_desc.op(i) for i in range(op_size)], op_size, set()))
|
|
|
|
|
|
|
|
|
|
while_op_output = set()
|
|
|
|
|
while_op_output.update(while_op_dict[parent_id].output_arg_names())
|
|
|
|
|
sub_block_pair = [("while", "while_grad"), ("parallel_do",
|
|
|
|
|
"parallel_do_grad")]
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
|
|
|
|
|
ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
|
|
|
|
|
|
|
|
|
|
cfgs = [
|
|
|
|
|
ControlFlowGraph(input_program, ops, forward_num, skip_opt)
|
|
|
|
@ -268,6 +278,6 @@ def get_cfgs(input_program):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def memory_optimize(input_program):
    """Run memory optimization on every control flow graph of a program.

    Builds the control flow graphs once via ``_get_cfgs`` and calls
    ``memory_optimize()`` on each of them.

    Args:
        input_program: the program to optimize; it is forwarded unchanged to
            ``_get_cfgs``.
    """
    # Only _get_cfgs is used: the earlier call to get_cfgs was immediately
    # overwritten and merely built the graphs a second time.
    cfgs = _get_cfgs(input_program)
    for cfg in cfgs:
        cfg.memory_optimize()
|
|
|
|
|