@@ -85,12 +85,6 @@ class Collective(Fleet):

    def save_persistables(self, executor, dirname, main_program=None):
        io.save_persistables(executor, dirname, main_program, None)

    def node_num(self):
        return self._role_maker._node_num

    def node_id(self):
        return self._role_maker._node_id


fleet = Collective()
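
# Illustrative usage of the `fleet` singleton (a sketch; `role_maker`, `exe`,
# and the checkpoint path are assumptions, not defined in this excerpt):
#
#   fleet.init(role_maker)
#   # ... build and run the training program ...
#   fleet.save_persistables(exe, "./checkpoint_dir")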

@@ -102,9 +96,6 @@ class DistributedStrategy(fluid.BuildStrategy):

    def __init__(self):
        super(DistributedStrategy, self).__init__()
        self.fuse_memory_size = -1
        self.fuse_layer_size = 1

        self.use_local_sgd = False
        self.use_dist_fc = False

@@ -112,21 +103,9 @@ class DistributedStrategy(fluid.BuildStrategy):

        self.dist_fc_config = None  # DistFCConfig
        self.mode = "nccl2"  # or collective
        self.collective_mode = None  # local_sgd or grad_allreduce

        self.nccl_comm_num = 1

        self.exec_strategy = fluid.ExecutionStrategy()
        sync_allreduce = os.getenv("FLAGS_sync_nccl_allreduce")
        if sync_allreduce is None or sync_allreduce == "1":
            self.exec_strategy.num_threads = self.nccl_comm_num + 1
            if self.use_hierarchical_allreduce:
                self.exec_strategy.num_threads = 2 * self.nccl_comm_num + 1
            if self.exec_strategy.num_threads > 4:
                print(
                    "WARNING: if you use hierarchical_allreduce or "
                    "multiple nccl comms, please export FLAGS_sync_nccl_allreduce=0",
                    file=sys.stderr)
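
# Illustrative configuration (hypothetical values; `use_hierarchical_allreduce`
# is inherited from fluid.BuildStrategy):
#
#   strategy = DistributedStrategy()
#   strategy.nccl_comm_num = 2
#   strategy.use_hierarchical_allreduce = True
#   strategy.mode = "collective"
#   strategy.collective_mode = "grad_allreduce"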


class CollectiveOpBasedOptimizer(DistributedOptimizer):

@@ -215,12 +194,6 @@ class CollectiveOptimizer(DistributedOptimizer):

        """
        Transpile the program into a distributed program, and add the
        variables the distributed training needs.
        """
        if self._strategy.fuse_all_reduce_ops:
            os.environ['FLAGS_fuse_parameter_memory_size'] = str(
                self._strategy.fuse_memory_size)
            os.environ['FLAGS_fuse_parameter_groups_size'] = str(
                self._strategy.fuse_layer_size)

        worker_endpoints = fleet.worker_endpoints()
        trainer_id = fleet.worker_index()
        current_endpoint = fleet.worker_endpoints()[trainer_id]
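
        # Example (hypothetical addresses): with worker_endpoints
        # ["10.0.0.1:6170", "10.0.0.2:6170"] and trainer_id 1,
        # current_endpoint resolves to "10.0.0.2:6170".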

@@ -249,7 +222,67 @@ class CollectiveOptimizer(DistributedOptimizer):

            program=main_program,
            current_endpoint=current_endpoint)

    def _get_node_ips_from_endpoints(self, endpoints):
        # Collect the distinct node IPs, preserving first-seen order.
        ss = set()
        ips = []
        for ep in endpoints:
            ip = ep.split(":")[0].strip()
            if ip not in ss:
                ss.add(ip)
                ips.append(ip)
        return ips
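        # Example (hypothetical addresses): ["10.0.0.1:6170", "10.0.0.1:6171",
        # "10.0.0.2:6170"] yields ["10.0.0.1", "10.0.0.2"]; one entry per
        # machine, in first-seen order.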

    def _node_num(self):
        # One logical node per distinct IP among the worker endpoints.
        worker_endpoints = fleet.worker_endpoints()
        node_ips = self._get_node_ips_from_endpoints(worker_endpoints)
        node_num = len(node_ips)
        return node_num
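        # Example: two trainers per machine, i.e. four endpoints spread over
        # two distinct IPs (hypothetical addresses), gives node_num == 2 even
        # though there are 4 workers.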

    def _try_to_compile(self, startup_program, main_program):
        node_num = self._node_num()
        assert node_num >= 1, "nccl2 node_num must >= 1, now:{}".format(
            node_num)

        self._strategy.fuse_all_reduce_ops = True
        exec_strategy = self._strategy.exec_strategy

        if node_num <= 1:
            if self._strategy.nccl_comm_num > 1:
                logging.warn("set nccl_comm_num=1 since you only have 1 node.")
                self._strategy.nccl_comm_num = 1

            if self._strategy.use_hierarchical_allreduce:
                logging.warn(
                    "set use_hierarchical_allreduce=False since you only have 1 node."
                )
                self._strategy.use_hierarchical_allreduce = False

        sync_allreduce = os.getenv("FLAGS_sync_nccl_allreduce")
        if sync_allreduce is None or sync_allreduce == "1":
            exec_strategy.num_threads = self._strategy.nccl_comm_num + 1
            if self._strategy.use_hierarchical_allreduce:
                exec_strategy.num_threads = 2 * self._strategy.nccl_comm_num + 1
            if exec_strategy.num_threads > 4:
                logging.warn(
                    "if you use hierarchical_allreduce or "
                    "multiple nccl comms, please export FLAGS_sync_nccl_allreduce=0"
                )
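                # Worked example (hypothetical values): nccl_comm_num = 2 with
                # hierarchical allreduce enabled gives
                # num_threads = 2 * 2 + 1 = 5 > 4, which triggers this warning.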

        if self.print_config:
            print("node_num:", node_num, "num_threads:",
                  exec_strategy.num_threads, "use_hierarchical_allreduce:",
                  self._strategy.use_hierarchical_allreduce, "nccl_comm_num:",
                  self._strategy.nccl_comm_num, "FLAGS_sync_nccl_allreduce:",
                  sync_allreduce)
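            # e.g. a hypothetical two-node run might print:
            # node_num: 2 num_threads: 3 use_hierarchical_allreduce: False
            # nccl_comm_num: 2 FLAGS_sync_nccl_allreduce: None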

        self._transpile(startup_program, main_program)

        if self._strategy.mode == "collective":