|
|
|
@ -713,6 +713,14 @@ class ParameterServerLauncher(object):
|
|
|
|
|
else:
|
|
|
|
|
self.worker_endpoints = args.workers
|
|
|
|
|
|
|
|
|
|
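# Resolve the endpoint exported as PADDLE_GLOO_HTTP_ENDPOINT: use args.http_port
# when given, otherwise build "ip:port" from the first server's IP and get_ports().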
|
|
|
|
|
if args.http_port:
|
|
|
|
|
self.http_port = args.http_port
|
|
|
|
|
else:
|
|
|
|
|
http_port = get_ports(1, self.server_num + self.worker_num)
|
|
|
|
|
http_ip = self.server_endpoints.split(",")[0].split(":")[0]
|
|
|
|
|
self.http_port = http_ip + ":" + str(http_port[0])
|
|
|
|
|
|
|
|
|
|
# get heter worker envs
|
|
|
|
|
if self.distribute_mode == DistributeMode.PS_HETER:
|
|
|
|
|
if args.heter_worker_num:
|
|
|
|
@ -827,7 +835,8 @@ class ParameterServerLauncher(object):
|
|
|
|
|
|
|
|
|
|
self.start_pod_server(self.args, pod)
|
|
|
|
|
self.start_pod_worker(self.args, pod)
|
|
|
|
|
self.start_pod_heter_worker(self.args, pod)
|
|
|
|
|
if self.distribute_mode == DistributeMode.PS_HETER:
|
|
|
|
|
self.start_pod_heter_worker(self.args, pod)
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
"Please check servers, workers and heter_worker logs in {}/workerlog.*, {}/serverlog.* and {}/heterlog.*".
|
|
|
|
@ -887,7 +896,8 @@ class ParameterServerLauncher(object):
|
|
|
|
|
"POD_IP": cur_server.endpoint.split(":")[0],
|
|
|
|
|
"PADDLE_WITH_GLOO": "1",
|
|
|
|
|
"PADDLE_GLOO_RENDEZVOUS": "2",
|
|
|
|
|
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir
|
|
|
|
|
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
|
|
|
|
|
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
|
|
|
|
|
}
|
|
|
|
|
current_env.update(proc_env)
|
|
|
|
|
|
|
|
|
@ -938,7 +948,8 @@ class ParameterServerLauncher(object):
|
|
|
|
|
device_list = [str(x) for x in range(0, heter_device_num)]
|
|
|
|
|
|
|
|
|
|
for idx, cur_worker in enumerate(pod.workers):
|
|
|
|
|
device_id = str(device_list[idx % heter_device_num])
|
|
|
|
|
device_id = "0" if heter_device_num == 0 else str(device_list[
|
|
|
|
|
idx % heter_device_num])
|
|
|
|
|
proc_env = {
|
|
|
|
|
"PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints,
|
|
|
|
|
"PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints,
|
|
|
|
@ -954,6 +965,7 @@ class ParameterServerLauncher(object):
|
|
|
|
|
"FLAGS_selected_xpus": "0",
|
|
|
|
|
"CUDA_VISIBLE_DEVICES": device_id,
|
|
|
|
|
"XPU_VISIBLE_DEVICES": device_id,
|
|
|
|
|
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
|
|
|
|
|
}
|
|
|
|
|
current_env.update(proc_env)
|
|
|
|
|
|
|
|
|
@ -1022,6 +1034,7 @@ class ParameterServerLauncher(object):
|
|
|
|
|
"FLAGS_selected_xpus": "0",
|
|
|
|
|
"CUDA_VISIBLE_DEVICES": device_id,
|
|
|
|
|
"XPU_VISIBLE_DEVICES": device_id,
|
|
|
|
|
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
|
|
|
|
|
}
|
|
|
|
|
current_env.update(proc_env)
|
|
|
|
|
|
|
|
|
|