|
|
|
@ -199,6 +199,47 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
|
|
|
|
|
return op_descs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _callback_lookup_(op):
|
|
|
|
|
"""
|
|
|
|
|
Only used in _append_backward_ops_
|
|
|
|
|
Build and returns a callback function for certain op. For example
|
|
|
|
|
|
|
|
|
|
parallel_do: AllReduce
|
|
|
|
|
|
|
|
|
|
:param op:
|
|
|
|
|
:return: callback function
|
|
|
|
|
"""
|
|
|
|
|
print(op.type)
|
|
|
|
|
if op.type == 'parallel_do':
|
|
|
|
|
param_names = set(op.input('parameters'))
|
|
|
|
|
param_grad_names = [n + "@GRAD" for n in param_names]
|
|
|
|
|
|
|
|
|
|
class ParallelDoCallBack(object):
|
|
|
|
|
def __init__(self, param_grad_names):
|
|
|
|
|
self.has_inserted_nccl_init = False
|
|
|
|
|
self.param_grad_names = param_grad_names
|
|
|
|
|
|
|
|
|
|
def __call__(self, block, context):
|
|
|
|
|
# TODO(tonyyang-svail): insert nccl init
|
|
|
|
|
|
|
|
|
|
for o_param in context.output_names():
|
|
|
|
|
for o_argu in context.output(o_param):
|
|
|
|
|
if o_argu in self.param_grad_names:
|
|
|
|
|
print("reduce", o_argu)
|
|
|
|
|
op_desc = block.desc.append_op()
|
|
|
|
|
framework.Operator(
|
|
|
|
|
block,
|
|
|
|
|
type='fill_constant',
|
|
|
|
|
desc=op_desc,
|
|
|
|
|
inputs={},
|
|
|
|
|
attrs={'shape': [1], },
|
|
|
|
|
outputs={'Out': [block.create_var()]})
|
|
|
|
|
|
|
|
|
|
return ParallelDoCallBack(param_grad_names)
|
|
|
|
|
else:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_backward_ops_(block,
|
|
|
|
|
ops,
|
|
|
|
|
target_block,
|
|
|
|
@ -239,7 +280,8 @@ def _append_backward_ops_(block,
|
|
|
|
|
sub_block = program.block(op.block_attr("sub_block"))
|
|
|
|
|
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
|
|
|
|
|
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
|
|
|
|
|
no_grad_dict, grad_to_var)
|
|
|
|
|
no_grad_dict, grad_to_var,
|
|
|
|
|
_callback_lookup_(op))
|
|
|
|
|
grad_sub_block_list.append(grad_sub_block.desc)
|
|
|
|
|
|
|
|
|
|
# Getting op's corresponding grad_op
|
|
|
|
@ -258,7 +300,7 @@ def _append_backward_ops_(block,
|
|
|
|
|
for op_desc in grad_op_descs:
|
|
|
|
|
new_op_desc = target_block.desc.append_op()
|
|
|
|
|
new_op_desc.copy_from(op_desc)
|
|
|
|
|
callback(block=target_block, context=grad_to_var)
|
|
|
|
|
callback(block=target_block, context=new_op_desc)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
|
|
|
|
|