|
|
|
@@ -230,44 +230,19 @@ def _callback_lookup_(op):
|
|
|
|
|
|
|
|
|
|
def __call__(self, block, context):
|
|
|
|
|
if not self.has_inserted_nccl_init:
|
|
|
|
|
# global_block = block.program.global_block()
|
|
|
|
|
# op_desc = global_block.desc.append_op()
|
|
|
|
|
# var_desc = global_block.desc.var('nccl_com__do_not_change_')
|
|
|
|
|
# var_desc.set_type(core.VarDesc.VarType.NCCL_COM)
|
|
|
|
|
# self.nccl_com = global_block.create_var(
|
|
|
|
|
# name='nccl_com', type=core.VarDesc.VarType.NCCL_COM)
|
|
|
|
|
# framework.Operator(
|
|
|
|
|
# global_block,
|
|
|
|
|
# type='ncclInit',
|
|
|
|
|
# desc=op_desc,
|
|
|
|
|
# inputs={},
|
|
|
|
|
# outputs={'Communicator': [self.nccl_com]})
|
|
|
|
|
op_desc = _create_op_desc_(
|
|
|
|
|
"ncclInit",
|
|
|
|
|
{"parallel_scopes": self.parallel_scopes_name},
|
|
|
|
|
{"Communicator": ['nccl_com__do_not_change_']}, {})
|
|
|
|
|
# block.desc.append_op().copy_from(op_desc)
|
|
|
|
|
print(serialize_op_decs(op_desc))
|
|
|
|
|
block.program.global_block().desc.append_op().copy_from(
|
|
|
|
|
op_desc)
|
|
|
|
|
self.has_inserted_nccl_init = True
|
|
|
|
|
|
|
|
|
|
current_op_desc = context["__current_op_desc__"]
|
|
|
|
|
# print(serialize_op_decs(context))
|
|
|
|
|
for o_param in current_op_desc.output_names():
|
|
|
|
|
for o_argu in current_op_desc.output(o_param):
|
|
|
|
|
if o_argu in self.param_grad_names:
|
|
|
|
|
# # print("reduce", o_argu)
|
|
|
|
|
# op_desc = block.desc.append_op()
|
|
|
|
|
# op_desc.set_type("ncclAllReduce")
|
|
|
|
|
# op_desc.set_input("X", [o_argu])
|
|
|
|
|
#
|
|
|
|
|
# # FIXME(tonyyang-svail):
|
|
|
|
|
# # Looks like nccl_com has been changed to nccl_com_0
|
|
|
|
|
# op_desc.set_input("Communicator", ['nccl_com_0'])
|
|
|
|
|
# out_var = block.create_var()
|
|
|
|
|
# op_desc.set_output("Out", [out_var.name])
|
|
|
|
|
# op_desc.set_attr("reduction", "ncclSum")
|
|
|
|
|
allreduce_out_name = o_argu + "__nccl_all_reduce__"
|
|
|
|
|
op_desc = _create_op_desc_(
|
|
|
|
|
"ncclAllReduce", {
|
|
|
|
|