|
|
@ -125,7 +125,7 @@ def init_parallel_env():
|
|
|
|
if ParallelEnv().world_size < 2:
|
|
|
|
if ParallelEnv().world_size < 2:
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
# 3: init gloo context
|
|
|
|
# 3: init gloo context (step 1: http server start)
|
|
|
|
ep_rank_0 = ParallelEnv().trainer_endpoints[0].split(":")
|
|
|
|
ep_rank_0 = ParallelEnv().trainer_endpoints[0].split(":")
|
|
|
|
ep_rank = ParallelEnv().trainer_endpoints[ParallelEnv().rank].split(":")
|
|
|
|
ep_rank = ParallelEnv().trainer_endpoints[ParallelEnv().rank].split(":")
|
|
|
|
manager = Manager()
|
|
|
|
manager = Manager()
|
|
|
@ -138,22 +138,6 @@ def init_parallel_env():
|
|
|
|
http_server.daemon = True
|
|
|
|
http_server.daemon = True
|
|
|
|
http_server_d["running"] = True
|
|
|
|
http_server_d["running"] = True
|
|
|
|
http_server.start()
|
|
|
|
http_server.start()
|
|
|
|
wait_server_ready([ParallelEnv().trainer_endpoints[0]])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gloo_strategy = core.GlooParallelStrategy()
|
|
|
|
|
|
|
|
gloo_strategy.rank = ParallelEnv().rank
|
|
|
|
|
|
|
|
gloo_strategy.rank_num = ParallelEnv().world_size
|
|
|
|
|
|
|
|
gloo_strategy.ip_address = ep_rank_0[0]
|
|
|
|
|
|
|
|
gloo_strategy.ip_port = int(ep_rank_0[1])
|
|
|
|
|
|
|
|
default_init_timeout_seconds = 3600
|
|
|
|
|
|
|
|
default_run_timeout_seconds = 9999999
|
|
|
|
|
|
|
|
gloo_strategy.init_seconds = default_init_timeout_seconds
|
|
|
|
|
|
|
|
gloo_strategy.run_seconds = default_run_timeout_seconds
|
|
|
|
|
|
|
|
gloo = core.GlooParallelContext(gloo_strategy)
|
|
|
|
|
|
|
|
gloo.init()
|
|
|
|
|
|
|
|
if ParallelEnv().rank == 0:
|
|
|
|
|
|
|
|
http_server_d["running"] = False
|
|
|
|
|
|
|
|
http_server.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 4. init NCCL ParallelStrategy
|
|
|
|
# 4. init NCCL ParallelStrategy
|
|
|
|
strategy = ParallelStrategy()
|
|
|
|
strategy = ParallelStrategy()
|
|
|
@ -165,7 +149,7 @@ def init_parallel_env():
|
|
|
|
strategy.current_endpoint = ParallelEnv().current_endpoint
|
|
|
|
strategy.current_endpoint = ParallelEnv().current_endpoint
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(chenweihang): [ why config global place here? ]
|
|
|
|
# NOTE(chenweihang): [ why config global place here? ]
|
|
|
|
# the dygraph mode will be set to default mode,
|
|
|
|
# the dygraph mode will be set to default mode,
|
|
|
|
# users will not call `dygraph.guard` or `enable_dygraph`
|
|
|
|
# users will not call `dygraph.guard` or `enable_dygraph`
|
|
|
|
# directly, if they want to switch default place,
|
|
|
|
# directly, if they want to switch default place,
|
|
|
|
# they need to call a function to change default place,
|
|
|
|
# they need to call a function to change default place,
|
|
|
@ -177,6 +161,27 @@ def init_parallel_env():
|
|
|
|
parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
|
|
|
|
parallel_helper._set_parallel_ctx(core.NCCLParallelContext(strategy, place))
|
|
|
|
parallel_helper._init_parallel_ctx()
|
|
|
|
parallel_helper._init_parallel_ctx()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 5: init gloo context (step 2: gloo init)
|
|
|
|
|
|
|
|
# dividing init_gloo into two parts because nccl and gloo
|
|
|
|
|
|
|
|
# are separately looking for free ports which sometimes
|
|
|
|
|
|
|
|
# leads to port-conflict.
|
|
|
|
|
|
|
|
wait_server_ready([ParallelEnv().trainer_endpoints[0]])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
gloo_strategy = core.GlooParallelStrategy()
|
|
|
|
|
|
|
|
gloo_strategy.rank = ParallelEnv().rank
|
|
|
|
|
|
|
|
gloo_strategy.rank_num = ParallelEnv().world_size
|
|
|
|
|
|
|
|
gloo_strategy.ip_address = ep_rank_0[0]
|
|
|
|
|
|
|
|
gloo_strategy.ip_port = int(ep_rank_0[1])
|
|
|
|
|
|
|
|
default_init_timeout_seconds = 3600
|
|
|
|
|
|
|
|
default_run_timeout_seconds = 9999999
|
|
|
|
|
|
|
|
gloo_strategy.init_seconds = default_init_timeout_seconds
|
|
|
|
|
|
|
|
gloo_strategy.run_seconds = default_run_timeout_seconds
|
|
|
|
|
|
|
|
gloo = core.GlooParallelContext(gloo_strategy)
|
|
|
|
|
|
|
|
gloo.init()
|
|
|
|
|
|
|
|
if ParallelEnv().rank == 0:
|
|
|
|
|
|
|
|
http_server_d["running"] = False
|
|
|
|
|
|
|
|
http_server.join()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_rank():
|
|
|
|
def get_rank():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|