@@ -247,6 +247,125 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
    return op_descs


def _find_not_need_ops(grad_op_descs, forward_ops, input_grad_names_set):
    """
    Prune the program with a structural analysis of the computational graph.
    The nodes of the computational graph composed of backward OPs should be
    interconnected. If there are unconnected sub-graphs in the computational
    graph, these sub-graphs should be cut off.

    Args:
        grad_op_descs(list[core.OpDesc]): The candidate backward OpDescs.
        forward_ops(list[Operator]): The forward ops.
        input_grad_names_set(set): this set stores the names of the gradients
            generated by backward ops; it helps to prune the unnecessary
            backward ops.

    Return:
        (set[core.OpDesc]): A set of OpDescs which should be pruned.
    """
    class Var(object):
        def __init__(self, var_name):
            self.var_name = var_name
            self.gen_op = None
            self.pendding_ops = []

        def set_gen_op(self, gen_op):
            assert isinstance(gen_op, Op)
            assert self.gen_op is None
            self.gen_op = gen_op

        def add_pending_op(self, op):
            assert isinstance(op, Op)
            self.pendding_ops.append(op)

    class Op(object):
        def __init__(self, op_desc):
            self.op_desc = op_desc
            self.inputs = []
            self.outputs = []

        def insert_input(self, var):
            assert isinstance(var, Var)
            self.inputs.append(var)

        def insert_output(self, var):
            assert isinstance(var, Var)
            self.outputs.append(var)
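
    # var_versions maps each variable name to the list of Var nodes created
    # for it, one per write (SSA-like versioning): outputs always get a fresh
    # version via _create_node, while inputs bind to the latest existing
    # version (or a placeholder) via _create_or_get_last_version_node.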
    var_versions = dict()

    def _create_node(name):
        if name not in var_versions.keys():
            var_versions[name] = [Var(name)]
        else:
            var_versions[name].append(Var(name))
        return var_versions[name][-1]

    def _create_or_get_last_version_node(name):
        if name not in var_versions.keys():
            var_versions[name] = [Var(name)]
        return var_versions[name][-1]

    def _create_op_node(op_desc):
        op_node = Op(op_desc)
        for input in op_desc.input_arg_names():
            var = _create_or_get_last_version_node(name=input)
            var.add_pending_op(op_node)
            op_node.insert_input(var)

        for output in op_desc.output_arg_names():
            var = _create_node(name=output)
            var.set_gen_op(op_node)
            op_node.insert_output(var)

        return op_node

    # Record the forward vars.
    forward_vars_set = set() if input_grad_names_set is None else set(
        input_grad_names_set)
    for op in forward_ops:
        forward_vars_set.update(op.desc.input_arg_names())
        forward_vars_set.update(op.desc.output_arg_names())

    # Record the vars which are generated by the backward ops processed so far.
    backward_vars_set = set()
    # special_op_nodes holds the candidate sub-graph head nodes.
    special_op_nodes = set()
    for op_desc in grad_op_descs:
        input_set = set(op_desc.input_arg_names())
        # new_vars are created during backward and are not generated by any op.
        new_vars = input_set - forward_vars_set - backward_vars_set
        backward_vars_set.update(op_desc.output_arg_names())

        op_node = _create_op_node(op_desc)
        # If every input of this op is a new var, the op may head a
        # disconnected sub-graph.
        if len(new_vars) == len(input_set):
            special_op_nodes.add(op_node)

    not_need_op_descs = []
    # Traverse all candidate sub-graph heads to check whether they are
    # connected to the backward computational graph; if they are not,
    # collect their ops in not_need_op_descs.
    for special_op_node in special_op_nodes:
        op_list = [special_op_node]
        ready_vars = set(special_op_node.inputs)
        remove_ops = True
        candidate_ops = [special_op_node]
        while len(candidate_ops) > 0:
            op_node = candidate_ops.pop(0)
            if _all_in_set_(op_node.inputs, ready_vars):
                for out_var in op_node.outputs:
                    candidate_ops.extend(out_var.pendding_ops)
                    op_list.extend(out_var.pendding_ops)
                ready_vars.update(op_node.outputs)
            else:
                # This op also needs vars produced outside the candidate
                # sub-graph, so the sub-graph is connected and must be kept.
                remove_ops = False
                break
        if remove_ops:
            not_need_op_descs.extend([node.op_desc for node in op_list])

    return set(not_need_op_descs)


from .proto import framework_pb2
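
# For intuition, the block below is a minimal standalone sketch of the same
# connectivity check (an editorial illustration, not part of this diff):
# toy_find_not_need and the example op tuples are assumed, illustrative names.
# Each op is (op_name, inputs, outputs); known_vars plays the role of the
# forward vars plus input_grad_names_set.
def toy_find_not_need(ops, known_vars):
    produced = set(known_vars)
    consumers = {}  # var name -> ops that read it
    heads = []      # candidate sub-graph head ops
    for op in ops:
        _, inputs, outputs = op
        for v in inputs:
            consumers.setdefault(v, []).append(op)
        # An op none of whose inputs is produced by a forward op or an
        # earlier backward op may head a disconnected sub-graph.
        if not set(inputs) & produced:
            heads.append(op)
        produced.update(outputs)

    not_need = set()
    for head in heads:
        ready = set(head[1])  # the head's own inputs count as ready
        queue, visited, disconnected = [head], [head], True
        while queue:
            _, inputs, outputs = queue.pop(0)
            if not set(inputs) <= ready:
                disconnected = False  # this op also needs vars from the main graph
                break
            ready.update(outputs)
            for out in outputs:
                for nxt in consumers.get(out, []):
                    queue.append(nxt)
                    visited.append(nxt)
        if disconnected:
            not_need.update(name for name, _, _ in visited)
    return not_need

# "c_grad" reads only "tmp", which nothing produces, so it is pruned;
# "a_grad" reads the forward var "x", so it is kept.
print(toy_find_not_need(
    [("a_grad", ["x"], ["x@GRAD"]), ("c_grad", ["tmp"], ["unused@GRAD"])],
    known_vars={"x"}))  # -> {'c_grad'}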

@@ -276,7 +395,10 @@ def _append_backward_ops_(block,
        grad_to_var(dict)(output argument):
            key(str): grad variable name
            val(str): corresponding forward variable name
        callback(callable object): a callable object used to decorate new generated grad ops
        callbacks(callable object): a callable object used to decorate new generated grad ops
        input_grad_names_set(set): this set stores the names of the gradients
            generated by backward ops; input_grad_names_set helps to prune the
            unnecessary backward ops.
    """
    if callbacks is not None:
        assert (isinstance(callbacks, list))

@@ -342,6 +464,10 @@ def _append_backward_ops_(block,
    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
                                            no_grad_dict[block.idx])

    # Drop the backward ops that belong to disconnected sub-graphs.
    not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
    grad_op_descs = [
        op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
    ]
    # append op_desc in grad_op_descs to target_block
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
    backward = core.op_proto_and_checker_maker.OpRole.Backward