|
|
|
@ -223,15 +223,15 @@ def get_cfgs(input_program):
|
|
|
|
|
|
|
|
|
|
# Find while/while_grad block pair
|
|
|
|
|
for grad_id in while_grad_sub_block_ids:
|
|
|
|
|
parent_id = pdesc.block(grad_id).parent
|
|
|
|
|
if parent_id in while_sub_block_ids:
|
|
|
|
|
while_block_id_pair.append((parent_id, grad_id))
|
|
|
|
|
while_sub_block_ids.remove(parent_id)
|
|
|
|
|
forward_id = pdesc.block(grad_id).get_forward_block_idx()
|
|
|
|
|
if forward_id in while_sub_block_ids:
|
|
|
|
|
while_block_id_pair.append((forward_id, grad_id))
|
|
|
|
|
while_sub_block_ids.remove(forward_id)
|
|
|
|
|
|
|
|
|
|
# Get while/while_grad block ops
|
|
|
|
|
for parent_id, grad_id in while_block_id_pair:
|
|
|
|
|
for forward_id, grad_id in while_block_id_pair:
|
|
|
|
|
while_block_ops = []
|
|
|
|
|
while_block = pdesc.block(parent_id)
|
|
|
|
|
while_block = pdesc.block(forward_id)
|
|
|
|
|
while_block_op_size = while_block.op_size()
|
|
|
|
|
for i in range(while_block_op_size):
|
|
|
|
|
while_block_ops.append(while_block.op(i))
|
|
|
|
@ -242,21 +242,21 @@ def get_cfgs(input_program):
|
|
|
|
|
while_block_ops.append(while_grad_block.op(i))
|
|
|
|
|
|
|
|
|
|
while_op_output = set()
|
|
|
|
|
while_op_output.update(while_op_dict[parent_id].output_arg_names())
|
|
|
|
|
while_op_output.update(while_op_dict[forward_id].output_arg_names())
|
|
|
|
|
while_op_output.update(while_op_dict[grad_id].output_arg_names())
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
|
|
|
|
|
|
|
|
|
|
# Process rest while block ops
|
|
|
|
|
for parent_id in while_sub_block_ids:
|
|
|
|
|
for forward_id in while_sub_block_ids:
|
|
|
|
|
while_block_ops = []
|
|
|
|
|
while_block = pdesc.block(parent_id)
|
|
|
|
|
while_block = pdesc.block(forward_id)
|
|
|
|
|
while_block_op_size = while_block.op_size()
|
|
|
|
|
for i in range(while_block_op_size):
|
|
|
|
|
while_block_ops.append(while_block.op(i))
|
|
|
|
|
|
|
|
|
|
while_op_output = set()
|
|
|
|
|
while_op_output.update(while_op_dict[parent_id].output_arg_names())
|
|
|
|
|
while_op_output.update(while_op_dict[forward_id].output_arg_names())
|
|
|
|
|
|
|
|
|
|
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
|
|
|
|
|
|
|
|
|
|