|
|
|
@ -248,12 +248,15 @@ def _callback_lookup_(op):
|
|
|
|
|
if o_argu in self.param_grad_names:
|
|
|
|
|
allreduce_out_name = o_argu + "__nccl_all_reduce__"
|
|
|
|
|
op_desc = _create_op_desc_(
|
|
|
|
|
"ncclAllReduce", {
|
|
|
|
|
"ncclReduce",
|
|
|
|
|
{
|
|
|
|
|
"X": [o_argu],
|
|
|
|
|
"Communicator":
|
|
|
|
|
['nccl_com__do_not_change_']
|
|
|
|
|
}, {"Out": [allreduce_out_name]},
|
|
|
|
|
{"reduction": "ncclSum"})
|
|
|
|
|
},
|
|
|
|
|
{"Out": [allreduce_out_name]},
|
|
|
|
|
{"reduction": "ncclSum",
|
|
|
|
|
"root": 0}, )
|
|
|
|
|
block.desc.append_op().copy_from(op_desc)
|
|
|
|
|
|
|
|
|
|
op_desc = _create_op_desc_(
|
|
|
|
|