parent
23b6dfd07a
commit
212dcedc48
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,7 @@
|
||||
FROM paddledev/paddle:cpu-latest
|
||||
|
||||
MAINTAINER zjsxzong89@gmail.com
|
||||
|
||||
COPY start.sh /root/
|
||||
COPY start_paddle.py /root/
|
||||
CMD ["bash"," -c","/root/start.sh"]
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,43 @@
|
||||
apiVersion: batch/v1
|
||||
kind: Job
|
||||
metadata:
|
||||
name: paddle-cluster-job
|
||||
spec:
|
||||
parallelism: 3
|
||||
completions: 3
|
||||
template:
|
||||
metadata:
|
||||
name: paddle-cluster-job
|
||||
spec:
|
||||
volumes:
|
||||
- name: jobpath
|
||||
hostPath:
|
||||
path: /home/work/paddle_output
|
||||
containers:
|
||||
- name: trainer
|
||||
image: registry.baidu.com/public/paddle:mypaddle
|
||||
command: ["bin/bash", "-c", "/root/start.sh"]
|
||||
env:
|
||||
- name: JOB_NAME
|
||||
value: paddle-cluster-job
|
||||
- name: JOB_PATH
|
||||
value: /home/jobpath
|
||||
- name: JOB_NAMESPACE
|
||||
value: default
|
||||
- name: TRAIN_CONFIG_DIR
|
||||
value: recommendation
|
||||
- name: CONF_PADDLE_NIC
|
||||
value: eth0
|
||||
- name: CONF_PADDLE_PORT
|
||||
value: "7164"
|
||||
- name: CONF_PADDLE_PORTS_NUM
|
||||
value: "2"
|
||||
- name: CONF_PADDLE_PORTS_NUM_SPARSE
|
||||
value: "2"
|
||||
- name: CONF_PADDLE_GRADIENT_NUM
|
||||
value: "3"
|
||||
volumeMounts:
|
||||
- name: jobpath
|
||||
mountPath: /home/jobpath
|
||||
restartPolicy: Never
|
||||
|
@ -0,0 +1,19 @@
|
||||
#!/bin/sh
|
||||
set -eu
|
||||
|
||||
jobconfig=${JOB_PATH}"/"${JOB_NAME}"/"${TRAIN_CONFIG_DIR}
|
||||
cd /root
|
||||
cp -rf $jobconfig .
|
||||
cd $TRAIN_CONFIG_DIR
|
||||
|
||||
|
||||
python /root/start_paddle.py \
|
||||
--dot_period=10 \
|
||||
--ports_num_for_sparse=$CONF_PADDLE_PORTS_NUM \
|
||||
--log_period=50 \
|
||||
--num_passes=10 \
|
||||
--trainer_count=4 \
|
||||
--saving_period=1 \
|
||||
--local=0 \
|
||||
--config=./trainer_config.py \
|
||||
--use_gpu=0
|
@ -0,0 +1,159 @@
|
||||
#!/usr/bin/python
|
||||
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import requests
|
||||
import time
|
||||
import socket
|
||||
import os
|
||||
import argparse
|
||||
|
||||
|
||||
# configuration for cluster
|
||||
API = "/api/v1/namespaces/"
|
||||
JOBSELECTOR = "labelSelector=job-name="
|
||||
JOB_PATH = os.getenv("JOB_PATH") + "/" + os.getenv("JOB_NAME")
|
||||
JOB_PATH_DATA = JOB_PATH + "/data"
|
||||
JOB_PATH_OUTPUT = JOB_PATH + "/output"
|
||||
JOBNAME = os.getenv("JOB_NAME")
|
||||
NAMESPACE = os.getenv("JOB_NAMESPACE")
|
||||
PADDLE_NIC = os.getenv("CONF_PADDLE_NIC")
|
||||
PADDLE_PORT = os.getenv("CONF_PADDLE_PORT")
|
||||
PADDLE_PORTS_NUM = os.getenv("CONF_PADDLE_PORTS_NUM")
|
||||
PADDLE_PORTS_NUM_SPARSE = os.getenv("CONF_PADDLE_PORTS_NUM_SPARSE")
|
||||
PADDLE_SERVER_NUM = os.getenv("CONF_PADDLE_GRADIENT_NUM")
|
||||
|
||||
|
||||
def refine_unknown_args(cmd_args):
|
||||
'''
|
||||
refine unknown parameters to handle some special parameters
|
||||
'''
|
||||
new_args = []
|
||||
for arg in cmd_args:
|
||||
if arg.startswith("--") and arg.find("=") != -1:
|
||||
equal_pos = arg.find("=") # find first = pos
|
||||
arglist = list(arg)
|
||||
arglist[equal_pos] = " "
|
||||
arg = "".join(arglist)
|
||||
arg = arg.lstrip("-")
|
||||
new_args += arg.split(" ")
|
||||
elif arg.startswith("--") and arg.find("=") == -1:
|
||||
arg = arg.lstrip("-")
|
||||
new_args.append(arg)
|
||||
else:
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
|
||||
def isPodAllRunning(podlist):
|
||||
'''
|
||||
check all pod is running
|
||||
'''
|
||||
require = len(podlist["items"])
|
||||
running = 0
|
||||
for pod in podlist["items"]:
|
||||
if pod["status"]["phase"] == "Running":
|
||||
running += 1
|
||||
if require == running:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def getPodList():
|
||||
'''
|
||||
get all container status of the job
|
||||
'''
|
||||
apiserver = "https://" + \
|
||||
os.getenv("KUBERNETES_SERVICE_HOST") + ":" + \
|
||||
os.getenv("KUBERNETES_SERVICE_PORT_HTTPS")
|
||||
|
||||
pod = API + NAMESPACE + "/pods?"
|
||||
job = JOBNAME
|
||||
return requests.get(apiserver + pod + JOBSELECTOR + job,
|
||||
verify=False).json()
|
||||
|
||||
|
||||
def getIdMap(podlist):
|
||||
'''
|
||||
generate tainer_id by ip
|
||||
'''
|
||||
ips = []
|
||||
for pod in podlist["items"]:
|
||||
ips.append(pod["status"]["podIP"])
|
||||
ips.sort()
|
||||
idMap = {}
|
||||
for i in range(len(ips)):
|
||||
idMap[ips[i]] = i
|
||||
return idMap
|
||||
|
||||
|
||||
def startPaddle(idMap={}, train_args_dict=None):
|
||||
'''
|
||||
start paddle pserver and trainer
|
||||
'''
|
||||
program = 'paddle train'
|
||||
args = " --nics=" + PADDLE_NIC
|
||||
args += " --port=" + str(PADDLE_PORT)
|
||||
args += " --ports_num=" + str(PADDLE_PORTS_NUM)
|
||||
args += " --comment=" + "paddle_process_by_paddle"
|
||||
ip_string = ""
|
||||
for ip in idMap.keys():
|
||||
ip_string += (ip + ",")
|
||||
ip_string = ip_string.rstrip(",")
|
||||
args += " --pservers=" + ip_string
|
||||
args_ext = ""
|
||||
for key, value in train_args_dict.items():
|
||||
args_ext += (' --' + key + '=' + value)
|
||||
localIP = socket.gethostbyname(socket.gethostname())
|
||||
trainerId = idMap[localIP]
|
||||
args += " " + args_ext + " --trainer_id=" + \
|
||||
str(trainerId) + " --save_dir=" + JOB_PATH_OUTPUT
|
||||
logDir = JOB_PATH_OUTPUT + "/node_" + str(trainerId)
|
||||
if not os.path.exists(JOB_PATH_OUTPUT):
|
||||
os.makedirs(JOB_PATH_OUTPUT)
|
||||
os.mkdir(logDir)
|
||||
copyCommand = 'cp -rf ' + JOB_PATH_DATA + \
|
||||
"/" + str(trainerId) + " ./data"
|
||||
os.system(copyCommand)
|
||||
startPserver = 'nohup paddle pserver' + \
|
||||
" --port=" + str(PADDLE_PORT) + \
|
||||
" --ports_num=" + str(PADDLE_PORTS_NUM) + \
|
||||
" --ports_num_for_sparse=" + str(PADDLE_PORTS_NUM_SPARSE) + \
|
||||
" --nics=" + PADDLE_NIC + \
|
||||
" --comment=" + "paddle_process_by_paddle" + \
|
||||
" --num_gradient_servers=" + str(PADDLE_SERVER_NUM) +\
|
||||
" > " + logDir + "/server.log 2>&1 &"
|
||||
print startPserver
|
||||
os.system(startPserver)
|
||||
# wait until pservers completely start
|
||||
time.sleep(10)
|
||||
startTrainer = program + args + " > " + \
|
||||
logDir + "/train.log 2>&1 < /dev/null"
|
||||
print startTrainer
|
||||
os.system(startTrainer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(prog="start_paddle.py",
|
||||
description='simple tool for k8s')
|
||||
args, train_args_list = parser.parse_known_args()
|
||||
train_args = refine_unknown_args(train_args_list)
|
||||
train_args_dict = dict(zip(train_args[:-1:2], train_args[1::2]))
|
||||
podlist = getPodList()
|
||||
# need to wait until all pods are running
|
||||
while not isPodAllRunning(podlist):
|
||||
time.sleep(10)
|
||||
podlist = getPodList()
|
||||
idMap = getIdMap(podlist)
|
||||
startPaddle(idMap, train_args_dict)
|
Loading…
Reference in new issue