You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/benchmark/cluster/vgg16/tf_k8s

83 lines
2.1 KiB

#!/bin/bash
#######################################
# Report a trainer's exit status and terminate this script with it.
# Logs the status to stdout, writes a short reason to the pod termination
# log for a handful of known codes, then exits with that same status.
# Globals:
#   TERMINATION_LOG_PATH (read, optional) - file that receives the reason
#     text; defaults to /dev/termination-log.
# Arguments:
#   $1 - exit status of the trainer process (treated as 0 when omitted)
# Outputs:
#   Progress messages on stdout.
# Returns:
#   Does not return; exits the current shell with the given status.
#######################################
check_trainer_ret() {
  local ret="${1:-0}"
  local log_path="${TERMINATION_LOG_PATH:-/dev/termination-log}"
  local reason=""

  # Builtin printf output is written as soon as the builtin finishes, so no
  # external stdbuf/echo is needed to get line-by-line log output.
  printf '%s\n' "job returned ${ret}...setting pod return message..."
  printf '%s\n' "==============================="

  # 134/136/139 are 128 + signal number (SIGABRT, SIGFPE, SIGSEGV), the
  # status a shell reports for a child killed by that signal.
  case "${ret}" in
    136) reason="Error Arithmetic Operation(Floating Point Exception)" ;;
    139) reason="Segmentation Fault" ;;
    1) reason="General Error" ;;
    134) reason="Program Abort" ;;
  esac

  # Any other status leaves the termination log untouched.
  if [[ -n "${reason}" ]]; then
    printf '%s\n' "${reason}" > "${log_path}"
  fi

  printf '%s\n' "termination log wroted..."
  exit "${ret}"
}
# Endpoint strings shared between functions. Both start out empty and are
# filled in by wait_running_pods from the output of
# "k8s_tools.py fetch_endpoints"; the start_tf_* launchers pass them on as
# --ps_hosts / --worker_hosts.
g_pservers=
g_trainers=
#######################################
# Wait for the job's pserver and trainer pods, then record their endpoints.
# Runs "k8s_tools.py wait_pods_running" once per role (presumably blocking
# until the requested number of pods is running - confirm against
# k8s_tools.py), then stores the output of "k8s_tools.py fetch_endpoints"
# for each role in the global endpoint variables.
# Globals:
#   JOB_NAME     (read)    - used to build the pod label selectors
#   PSERVERS_NUM (read)    - passed to wait_pods_running for pservers
#   TRAINERS_NUM (read)    - passed to wait_pods_running for trainers
#   PORT         (read)    - passed to fetch_endpoints
#   g_pservers   (written) - fetch_endpoints output for the pserver label
#   g_trainers   (written) - fetch_endpoints output for the trainer label
# Returns:
#   Exit status of the last fetch_endpoints call. The status of the
#   wait_pods_running calls is not checked.
#######################################
wait_running_pods() {
  local pserver_label="tf-job-pserver=${JOB_NAME}"
  local trainer_label="tf-job-trainer=${JOB_NAME}"

  # Every expansion is quoted so each value reaches k8s_tools.py as exactly
  # one argument, with no word splitting or globbing.
  stdbuf -oL python /root/k8s_tools.py wait_pods_running \
    "${pserver_label}" "${PSERVERS_NUM}"
  stdbuf -oL python /root/k8s_tools.py wait_pods_running \
    "${trainer_label}" "${TRAINERS_NUM}"

  g_pservers=$(python /root/k8s_tools.py fetch_endpoints \
    "${pserver_label}" "${PORT}")
  g_trainers=$(python /root/k8s_tools.py fetch_endpoints \
    "${trainer_label}" "${PORT}")
}
#######################################
# Launch the TensorFlow parameter-server process for this pod.
# Waits for the job's pods, asks k8s_tools.py for this pod's id under the
# pserver label, then runs ENTRY from inside TRAINER_PACKAGE with the
# cluster flags appended.
# Globals:
#   JOB_NAME        (read) - used to build the pserver label
#   ENTRY           (read) - command line to run; it is handed to "sh -c",
#                            so it is parsed by the shell (presumably
#                            something like "python script.py" - confirm)
#   TRAINER_PACKAGE (read) - directory the command is run from
#   TF_JOB_NAME     (read) - value for --job_name
#   g_pservers, g_trainers (read) - set by wait_running_pods
# Returns:
#   Exit status of the launched command, or of cd if the directory change
#   fails.
#######################################
start_tf_pserver() {
  local label pserver_id cmd

  wait_running_pods
  label="tf-job-pserver=${JOB_NAME}"
  pserver_id=$(python /root/k8s_tools.py fetch_id "${label}")

  cmd="${ENTRY} --ps_hosts=${g_pservers} --worker_hosts=${g_trainers}"
  cmd+=" --job_name=${TF_JOB_NAME} --task_index=${pserver_id}"

  # Change directory in a subshell instead of splicing the path into the
  # "sh -c" string: the path is then never re-parsed by a shell, so spaces
  # or metacharacters in TRAINER_PACKAGE are safe, and the caller's working
  # directory is left untouched.
  (cd -- "${TRAINER_PACKAGE}" && stdbuf -oL sh -c "${cmd}")
}
#######################################
# Launch the TensorFlow trainer (worker) process for this pod and report
# its exit status.
# Waits for the job's pods, asks k8s_tools.py for this pod's id under the
# trainer label, runs ENTRY from inside TRAINER_PACKAGE with the cluster
# flags appended, then hands the resulting status to check_trainer_ret.
# Globals:
#   JOB_NAME        (read) - used to build the trainer label
#   ENTRY           (read) - command line to run; it is handed to "sh -c",
#                            so it is parsed by the shell (presumably
#                            something like "python script.py" - confirm)
#   TRAINER_PACKAGE (read) - directory the command is run from
#   TF_JOB_NAME     (read) - value for --job_name
#   BATCH_SIZE      (read) - value for --batch_size
#   g_pservers, g_trainers (read) - set by wait_running_pods
# Returns:
#   Whatever check_trainer_ret does with the command's exit status.
#######################################
start_tf_trainer() {
  local label trainer_id cmd

  wait_running_pods
  label="tf-job-trainer=${JOB_NAME}"
  trainer_id=$(python /root/k8s_tools.py fetch_id "${label}")

  cmd="${ENTRY} --ps_hosts=${g_pservers} --worker_hosts=${g_trainers}"
  cmd+=" --job_name=${TF_JOB_NAME} --task_index=${trainer_id}"
  cmd+=" --batch_size=${BATCH_SIZE}"

  # Change directory in a subshell instead of splicing the path into the
  # "sh -c" string: the path is then never re-parsed by a shell, so spaces
  # or metacharacters in TRAINER_PACKAGE are safe. The subshell's status is
  # the command's status (or cd's, if the directory change fails).
  (cd -- "${TRAINER_PACKAGE}" && stdbuf -oL sh -c "${cmd}")
  check_trainer_ret "$?"
}
#######################################
# Start the TensorFlow role selected by TF_JOB_NAME.
# Globals:
#   TF_JOB_NAME (read) - exactly "worker" starts a trainer; any other
#                        value (including empty or unset) starts a pserver
# Returns:
#   Exit status of the launcher that was called.
#######################################
start_tf() {
  case "${TF_JOB_NAME}" in
    worker)
      start_tf_trainer
      ;;
    *)
      start_tf_pserver
      ;;
  esac
}
#######################################
# Print the command-line help text.
# Outputs:
#   Two lines of usage text on stdout.
#######################################
usage() {
  printf '%s\n' \
    "usage: tf_k8s [<args>]:" \
    " start_tf Start tensorflow jobs"
}
# Command-line dispatch: "start_tf" launches the job; "--help", any other
# first argument, or no argument at all prints the usage text.
if [[ "$1" == "start_tf" ]]; then
  start_tf
else
  usage
fi