|
|
|
@ -16,7 +16,7 @@ from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
|
'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
|
|
|
|
|
'UserDefinedCollectiveRoleMaker'
|
|
|
|
|
'UserDefinedCollectiveRoleMaker', 'PaddleCloudRoleMaker'
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -292,6 +292,50 @@ class MPISymetricRoleMaker(MPIRoleMaker):
|
|
|
|
|
self._role_is_generated = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PaddleCloudRoleMaker(RoleMakerBase):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(PaddleCloudRoleMaker, self).__init__()
|
|
|
|
|
|
|
|
|
|
def generate_role(self):
|
|
|
|
|
if not self._role_is_generated:
|
|
|
|
|
self.port = os.getenv("PADDLE_PORT", "6174")
|
|
|
|
|
self.pserver_ips = os.getenv("PADDLE_PSERVERS", "")
|
|
|
|
|
eplist = []
|
|
|
|
|
for ip in pserver_ips.split(","):
|
|
|
|
|
eplist.append(':'.join([ip, port]))
|
|
|
|
|
self.endpoints = ",".join(eplist)
|
|
|
|
|
self.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
|
|
|
|
|
self.current_endpoint = os.getenv("POD_IP",
|
|
|
|
|
"localhost") + ":" + port
|
|
|
|
|
self.role = os.getenv("TRAINING_ROLE", "TRAINER")
|
|
|
|
|
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
|
|
|
|
|
self.eplist = eplist
|
|
|
|
|
self.endpoints = self.endpoints.split(",")
|
|
|
|
|
if self.role.upper() == "PSERVER":
|
|
|
|
|
self.current_id = self.endpoints.index(self.current_endpoint)
|
|
|
|
|
else:
|
|
|
|
|
self.current_id = self.trainer_id
|
|
|
|
|
self._role_is_generated = True
|
|
|
|
|
|
|
|
|
|
def is_wokrer(self):
|
|
|
|
|
return self._role == Role.WORKER
|
|
|
|
|
|
|
|
|
|
def is_server(self):
|
|
|
|
|
return self._role == Role.SERVER
|
|
|
|
|
|
|
|
|
|
def is_first_worker(self):
|
|
|
|
|
return self._role == Role.WORKER and self._current_id == 0
|
|
|
|
|
|
|
|
|
|
def worker_index(self):
|
|
|
|
|
return self._current_id
|
|
|
|
|
|
|
|
|
|
def server_index(self):
|
|
|
|
|
return self._current_id
|
|
|
|
|
|
|
|
|
|
def worker_num(self):
|
|
|
|
|
return self._worker_num
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UserDefinedRoleMaker(RoleMakerBase):
|
|
|
|
|
def __init__(self,
|
|
|
|
|
current_id=0,
|
|
|
|
@ -329,6 +373,9 @@ class UserDefinedRoleMaker(RoleMakerBase):
|
|
|
|
|
else:
|
|
|
|
|
self._server_endpoints = server_endpoints
|
|
|
|
|
|
|
|
|
|
def generate_role(self):
|
|
|
|
|
self._role_is_generated = True
|
|
|
|
|
|
|
|
|
|
def is_worker(self):
|
|
|
|
|
return self._role == Role.WORKER
|
|
|
|
|
|
|
|
|
@ -369,6 +416,9 @@ class UserDefinedCollectiveRoleMaker(RoleMakerBase):
|
|
|
|
|
self._worker_endpoints = worker_endpoints
|
|
|
|
|
self._worker_num = len(self._worker_endpoints)
|
|
|
|
|
|
|
|
|
|
def generate_role(self):
|
|
|
|
|
self._role_is_generated = True
|
|
|
|
|
|
|
|
|
|
def is_worker(self):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|