@@ -554,8 +554,6 @@ def serialize_op_decs(op_desc):
def _append_backward_ops_with_checkpoints_(
block, ops, target_block, no_grad_dict, grad_to_var, checkpoints):
checkpoints_name = [x.name for x in checkpoints]
"""
Create grad ops with forward ops, and insert them into given block
@@ -569,25 +567,27 @@ def _append_backward_ops_with_checkpoints_(
checkpoints: variables that a user defined as checkpoint for forward recomputation
Algorithms:
1) go through all forward ops and induct all checkpoint vars
a. input variables can be deduced from forward program
b. input variables are checkpoints
c. variables that are used across segments will be held in memory
2) find ops between checkpoints, i.e. recompute_segments
1) find ops between checkpoints, i.e. recompute_segments
2) go through all forward ops and collect all variables that will be held in memory
a. variables that are used across segments will be held in memory
b. output of dropout op will be held in memory
c. input variables will be held in memory
3) go through each recompute_segments, add backward ops with forward recomputation
a. add ops in current recompute_segment as forward recomputation ops
b. rename all non-checkpoint variables in recomputation ops
c. add sum_op to merge gradient if needed
d. add backward ops of current recomputation ops
c. add backward ops of current recomputation ops
d. add sum op for repetitive_outputs
4) remove no grad branch as it is in _remove_no_grad_branch_
5) Note1: all appended ops' OpRole are Backward
6) Note2: variables that are used across segments will be held in memory
7) Note3: all variables with new name should be returned so that _append_backward_vars_ can be called
8) Note4: current forward recomputation backpropagation does not handle programs with subblock
6) Note2: all variables with new names should be returned so that _append_backward_vars_ can be called
7) Note3: current forward recomputation backpropagation does not handle programs with subblock
"""
checkpoints_name = [x.name for x in checkpoints]
local_block = block.program._create_block()
buffer_block = block.program._create_block()
# 1) find ops between checkpoints, i.e. recompute_segments
program_stat = ProgramStats(block, ops)
program_stat.build_stats()
segments = []
@@ -622,11 +622,16 @@ def _append_backward_ops_with_checkpoints_(
recompute_segments = [[0, segments[0][0]]] + segments
else:
recompute_segments = segments
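# recompute_segments now holds slice-style [start, end] index pairs into
# `ops`; for example, with two user checkpoints it might look like
# [[0, 4], [5, 8]] (illustrative values only, not from a real program).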
# 2) go through all forward ops and collect all variables that will be held in memory
vars_should_be_hold = []
# a. variables that are used across segments will be held in memory
for segment in recompute_segments:
vars_should_be_hold.extend(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
# b. output of dropout op will be held in memory
vars_should_be_hold.extend(program_stat.get_reserved_vars())
# c. input variables are checkpoints
vars_should_be_hold.extend(program_stat.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold))
@@ -634,6 +639,7 @@ def _append_backward_ops_with_checkpoints_(
grad_should_be_hold = [x + "@GRAD" for x in vars_should_be_hold]
vars_should_be_hold.extend(grad_should_be_hold)
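# The matching "@GRAD" names are added so that gradients of held variables
# are, like their forward counterparts, kept under their original names and
# skipped by the per-segment renaming below.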
# 3) go through each recompute_segments, add backward ops with forward recomputation
grad_op_descs = []
var_name_dict = {}
@@ -641,6 +647,8 @@ def _append_backward_ops_with_checkpoints_(
max_calculated_op_position = len(ops)
if recompute_segments == []:
# if there is no recompute segment, add backward ops like
# _append_backward_ops_ function
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
if op.has_attr("sub_block"):
@@ -686,30 +694,30 @@ def _append_backward_ops_with_checkpoints_(
continue
if name not in var_name_dict:
var_name_dict[name] = name + var_suffix
# 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory)
added_descs = _add_descs_to_block(ff_ops, local_block)
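# Two copies of the segment's forward ops are created: buffer_descs (renamed
# below and emitted as the actual recomputation ops) and added_descs
# (unrenamed, used only to derive the corresponding grad op descs).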
# rename variable names in added_descs
# 3.b. rename all non-checkpoint variables in recomputation ops
for key in var_name_dict:
_rename_arg_(buffer_descs, key, var_name_dict[key])
# added_descs should be in grad_op_descs because it is backward op desc
grad_op_descs.extend(buffer_descs)
#for op_desc in reversed(buffer_descs):
# 3.c. add backward ops of current recomputation ops
for op_desc in reversed(added_descs):
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
for key in var_name_dict:
_rename_arg_(grad_op_desc, key, var_name_dict[key])
grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var)
# 3.d. add sum op for repetitive_outputs
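# When one forward variable feeds several ops, each of their backward ops
# writes a partial gradient under the same name; _addup_repetitive_outputs_
# inserts sum ops so these partials are accumulated rather than overwritten.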
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs)
# 4) remove no grad branch as it is in _remove_no_grad_branch_
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
no_grad_dict[block.idx])
added_descs = _add_descs_to_block(grad_op_descs, target_block)