add parameter server launch (#18687)
add parameter server launch so that a user can easily launch a parameter server
parent d07ad4c605
commit 70b03760fd
@@ -0,0 +1,151 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
from __future__ import unicode_literals

import subprocess
import sys
import os
import copy
from argparse import ArgumentParser, REMAINDER

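
# This helper spawns server_num parameter-server processes and worker_num
# trainer processes on the local node, passing each child its role and
# index through environment variables set in start_procs() below.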
def parse_args():
    # Optional arguments for the launch helper
    parser = ArgumentParser(description="Distributed training")
    parser.add_argument(
        "--cluster_node_ips",
        type=str,
        default="127.0.0.1",
        help="Paddle cluster node IPs, such as 192.168.0.16,192.168.0.17. "
        "Currently unused: all processes are launched on the local node.")

    parser.add_argument(
        "--node_ip",
        type=str,
        default="127.0.0.1",
        help="The current node's IP. Currently unused.")

    parser.add_argument(
        "--start_port",
        type=int,
        default=6170,
        help="The first port assigned to the parameter servers on this "
        "node; server i listens on start_port + i.")

    # NOTE: bool("False") is True, so any explicit value passed on the
    # command line enables this flag; only omitting it keeps the default.
    parser.add_argument(
        "--print_config",
        type=bool,
        default=True,
        help="Print the config or not")

    parser.add_argument(
        "--worker_num", type=int, default=2, help="number of workers")

    parser.add_argument(
        "--server_num", type=int, default=2, help="number of servers")

    parser.add_argument(
        "--log_dir",
        default="logs",
        type=str,
        help="The path for each process's log. If not set, the logs are "
        "printed to the default pipe.")

    # positional
    parser.add_argument(
        "training_script",
        type=str,
        help="The full path to the training program/script to be launched "
        "in parallel, followed by all the arguments for the training script")

    # rest from the training program
    parser.add_argument('training_script_args', nargs=REMAINDER)
    return parser.parse_args()

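
# A hypothetical invocation, assuming this file is saved as launch_ps.py
# and train.py is your training script (arguments after the script name
# are forwarded to it verbatim):
#
#   python launch_ps.py --server_num 2 --worker_num 2 train.py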
def start_procs(args):
    worker_num = args.worker_num
    server_num = args.server_num
    start_port = args.start_port
    default_env = os.environ.copy()
    current_env = copy.copy(default_env)
    # Proxies would interfere with local inter-process communication.
    current_env.pop("http_proxy", None)
    current_env.pop("https_proxy", None)
    procs = []
    cmds = []
    log_fns = []
    # One port per server, e.g. "127.0.0.1:6170,127.0.0.1:6171" with the
    # defaults server_num=2 and start_port=6170.
    ports = range(start_port, start_port + server_num, 1)
    endpoints = ",".join(["127.0.0.1:" + str(x) for x in ports])

    # Launch the parameter-server processes first. Each child learns its
    # role, its index and the full endpoint list from environment variables
    # (a consumer sketch appears at the end of this file).
    for i in range(server_num):
        current_env.update({
            "TRAINER_NUM": str(worker_num),
            "CURRENT_ID": str(i),
            "ENDPOINTS": endpoints,
            "TRAINING_ROLE": "PSERVER"
        })
        cmd = [sys.executable, "-u", args.training_script
               ] + args.training_script_args
        cmds.append(cmd)
        print(cmd)
        if args.log_dir is not None:
            if not os.path.isdir(args.log_dir):
                os.makedirs(args.log_dir)
            fn = open("%s/serverlog.%d" % (args.log_dir, i), "w")
            log_fns.append(fn)
            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
        else:
            proc = subprocess.Popen(cmd, env=current_env)
        procs.append(proc)

    # Then launch the trainer processes with the same endpoint list.
    for i in range(worker_num):
        current_env.update({
            "ENDPOINTS": endpoints,
            "TRAINER_NUM": str(worker_num),
            "TRAINING_ROLE": "TRAINER",
            "CURRENT_ID": str(i)
        })
        cmd = [sys.executable, "-u", args.training_script
               ] + args.training_script_args
        print(cmd)
        cmds.append(cmd)
        if args.log_dir is not None:
            if not os.path.isdir(args.log_dir):
                os.makedirs(args.log_dir)
            fn = open("%s/workerlog.%d" % (args.log_dir, i), "w")
            log_fns.append(fn)
            proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
        else:
            proc = subprocess.Popen(cmd, env=current_env)
        procs.append(proc)

    # Wait for every child and surface the first non-zero exit code.
    for i in range(0, len(procs)):
        proc = procs[i]

        proc.wait()
        if len(log_fns) > 0:
            log_fns[i].close()

        if proc.returncode != 0:
            raise subprocess.CalledProcessError(
                returncode=procs[i].returncode, cmd=cmds[i])

def launch():
    args = parse_args()
    if args.print_config:
        print("launch config:", args)
    start_procs(args)


if __name__ == "__main__":
    launch()
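
# A minimal sketch of a hypothetical train.py consuming the variables this
# launcher exports (everything except the environment variable names is
# made up):
#
#   import os
#
#   role = os.environ["TRAINING_ROLE"]              # "PSERVER" or "TRAINER"
#   current_id = int(os.environ["CURRENT_ID"])      # index within that role
#   endpoints = os.environ["ENDPOINTS"].split(",")  # ["ip:port", ...]
#   trainer_num = int(os.environ["TRAINER_NUM"])    # number of trainers
#
#   if role == "PSERVER":
#       pass  # start a parameter server listening on endpoints[current_id]
#   else:
#       pass  # start a trainer that connects to all the endpoints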