@@ -47,7 +47,7 @@ def is_optimizer_op(op):


 class CollectiveHelper(object):
-    def __init__(self, role_maker, nrings=1, wait_port='6174'):
+    def __init__(self, role_maker, nrings=1, wait_port=True):
         self.nrings = nrings
         self.wait_port = wait_port
         self.role_maker = role_maker
@@ -65,14 +65,48 @@ class CollectiveHelper(object):
                 self.role_maker._worker_index(), ring_id, self.wait_port)
         self._broadcast_params()

-    def _init_communicator(self, program, current_endpoint, endpoints, rank,
-                           ring_id, wait_port):
+    def _init_communicator(self,
+                           program,
+                           current_endpoint,
+                           endpoints,
+                           rank,
+                           ring_id,
+                           wait_port,
+                           global_ring_id=None,
+                           sync=True):
         nranks = len(endpoints)
         other_endpoints = endpoints[:]
         other_endpoints.remove(current_endpoint)
         if rank == 0 and wait_port:
             wait_server_ready(other_endpoints)

+        def _add_sync_by_allreduce(block):
+            sync_var = block.create_var(
+                name=unique_name.generate('sync_var'),
+                dtype=core.VarDesc.VarType.INT32,
+                persistable=False,
+                stop_gradient=True)
+            block.append_op(
+                type='fill_constant',
+                inputs={},
+                outputs={'Out': [sync_var]},
+                attrs={
+                    'shape': [1],
+                    'dtype': sync_var.dtype,
+                    'value': 1,
+                    'force_cpu': False,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+            block.append_op(
+                type='c_allreduce_sum',
+                inputs={'X': [sync_var]},
+                outputs={'Out': [sync_var]},
+                attrs={
+                    'ring_id': global_ring_id,
+                    'use_calc_stream': True,
+                    OP_ROLE_KEY: OpRole.Forward
+                })
+
         block = program.global_block()
         if core.is_compiled_with_cuda():
             comm_id_var = block.create_var(
@@ -128,6 +162,7 @@ class CollectiveHelper(object):
             raise ValueError(
                 "comm_id must be generated in paddlepaddle-gpu or paddlepaddle-xpu."
             )
+        if sync: _add_sync_by_allreduce(block)

     def _wait(self, current_endpoint, endpoints):
         assert (self.wait_port)
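
The new `sync` flag guards a lightweight barrier: each worker fills a one-element INT32 tensor with 1 and allreduce-sums it over the global ring, so no worker can leave communicator initialization until every peer has reached the same point. Below is a minimal dynamic-graph sketch of the same sync-by-allreduce idea; it is illustrative only, not the code added by this diff, and assumes a multi-process launch (e.g. via paddle.distributed.launch).

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()
    # Every rank contributes 1; the allreduce cannot complete until all
    # ranks have joined, so it doubles as a barrier and a liveness check.
    sync_var = paddle.ones([1], dtype='int32')
    dist.all_reduce(sync_var)  # defaults to sum
    assert int(sync_var) == dist.get_world_size()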