add UserDefinedCollectiveRoleMaker for collective mode (#17898)

* add 'UserDefinedRoleMakerNCCL' for collective mode.

* code style

* add the name UserDefinedRoleMakerNCCL to __all__

* rename to UserDefinedRoleMakerCollective

* rename to UserDefinedCollectiveRoleMaker
lite
lilong12 6 years ago committed by GitHub
parent 84bb45c054
commit b5c35ae3e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,7 +16,8 @@ from __future__ import print_function
from enum import Enum
__all__ = [
'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker'
'Role', 'RoleMakerBase', 'MPISymetricRoleMaker', 'UserDefinedRoleMaker',
'UserDefinedCollectiveRoleMaker'
]
@ -346,3 +347,37 @@ class UserDefinedRoleMaker(RoleMakerBase):
def worker_num(self):
return self._worker_num
class UserDefinedCollectiveRoleMaker(RoleMakerBase):
def __init__(self, current_id=0, worker_endpoints=None):
"""
UserDefinedCollectiveRoleMaker is designed for worker assignment
under manual for collective mode.
"""
super(UserDefinedCollectiveRoleMaker, self).__init__()
if not isinstance(current_id, int):
raise TypeError("current_id must be as int")
else:
if current_id < 0:
raise ValueError("current_id must be greater or equal 0")
self._current_id = current_id
if not isinstance(worker_endpoints, list):
raise TypeError("worker_endpoints must be as string list")
else:
self._worker_endpoints = worker_endpoints
self._worker_num = len(self._worker_endpoints)
def is_worker(self):
return True
def is_first_worker(self):
return self._current_id == 0
def worker_index(self):
return self._current_id
def worker_num(self):
return self._worker_num

Loading…
Cancel
Save