@@ -228,8 +228,6 @@ def _callback_lookup_(op):
                 self.param_grad_names = param_grad_names

             def __call__(self, block, context):
-                # move to parallel_do.py
-                # # TODO(tonyyang-svail): insert nccl init
                 if not self.has_inserted_nccl_init:
                     global_block = block.program.global_block()
                     op_desc = global_block.desc.append_op()
@@ -250,16 +248,30 @@ def _callback_lookup_(op):
                 for o_param in current_op_desc.output_names():
                     for o_argu in current_op_desc.output(o_param):
                         if o_argu in self.param_grad_names:
-                            # print("reduce", o_argu)
-                            op_desc = block.desc.append_op()
-                            op_desc.set_type("ncclAllReduce")
-                            op_desc.set_input("X", [o_argu])
-                            # FIXME(tonyyang-svail):
-                            # Looks like nccl_com has been changed to nccl_com_0
-                            op_desc.set_input("Communicator", ['nccl_com_0'])
-                            out_var = block.create_var()
-                            op_desc.set_output("Out", [out_var.name])
-                            op_desc.set_attr("reduction", "ncclSum")
+                            # # print("reduce", o_argu)
+                            # op_desc = block.desc.append_op()
+                            # op_desc.set_type("ncclAllReduce")
+                            # op_desc.set_input("X", [o_argu])
+                            #
+                            # # FIXME(tonyyang-svail):
+                            # # Looks like nccl_com has been changed to nccl_com_0
+                            # op_desc.set_input("Communicator", ['nccl_com_0'])
+                            # out_var = block.create_var()
+                            # op_desc.set_output("Out", [out_var.name])
+                            # op_desc.set_attr("reduction", "ncclSum")
+                            allreduce_out_name = o_argu + "__nccl_all_reduce__"
+                            op_desc = _create_op_desc_(
+                                "ncclAllReduce", {
+                                    "X": [o_argu],
+                                    "Communicator": ['nccl_com_0']
+                                }, {"Out": [allreduce_out_name]},
+                                {"reduction": "ncclSum"})
+                            block.desc.append_op().copy_from(op_desc)
+
+                            op_desc = _create_op_desc_(
+                                "assign", {"X": [allreduce_out_name]},
+                                {"Out": [o_argu]}, {})
+                            block.desc.append_op().copy_from(op_desc)

         return ParallelDoCallBack(param_grad_names)
     else:
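
In the new hunk, the gradient-allreduce op is no longer assembled by mutating a freshly appended desc field by field (set_type, set_input, set_output, set_attr); instead _create_op_desc_ builds a complete OpDesc from the op type plus input, output and attribute maps, and block.desc.append_op().copy_from(op_desc) copies it into the block. A minimal sketch of what such a helper does, reusing the same setters the removed lines call; the core import path and the core.OpDesc() constructor are assumptions for illustration, not taken from this diff:

from paddle.v2.fluid import core  # assumed import path for this era of fluid


def _create_op_desc_(op_type, inputs, outputs, attrs):
    """Sketch: build an OpDesc from an op type plus input/output/attr maps."""
    op_desc = core.OpDesc()             # assumes the pybind-exposed OpDesc class
    op_desc.set_type(op_type)           # e.g. "ncclAllReduce" or "assign"
    for name, args in inputs.items():   # e.g. {"X": [o_argu], "Communicator": ['nccl_com_0']}
        op_desc.set_input(name, args)
    for name, args in outputs.items():  # e.g. {"Out": [allreduce_out_name]}
        op_desc.set_output(name, args)
    for name, val in attrs.items():     # e.g. {"reduction": "ncclSum"}
        op_desc.set_attr(name, val)
    return op_desc

The other change in shape: the allreduce result now goes into a named temporary (o_argu + "__nccl_all_reduce__") rather than an anonymous block.create_var(), and a follow-up assign op copies it back onto the gradient name in self.param_grad_names, so whatever later reads that gradient sees the ncclSum-reduced value.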