parent 541b42e6fb
commit d3905fbc1e
@@ -0,0 +1,12 @@
#FROM paddlepaddle/paddlecloud-job
#RUN mkdir -p /workspace
#ADD reader.py /workspace/
#RUN python /workspace/reader.py
FROM python:2.7.14
ADD *.whl /
RUN pip install /*.whl && rm -f /*.whl
ADD paddle_k8s /usr/bin
ADD k8s_tools.py /root
RUN pip install -U kubernetes opencv-python && apt-get update -y && apt-get install -y iputils-ping libgtk2.0-dev

ADD vgg16.py /workspace/
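The image built from this Dockerfile is the one the pserver and trainer manifests below pull (registry.baidu.com/paddlepaddle/rawjob:vgg16_fluid). A minimal build-and-push sketch, assuming the PaddlePaddle Fluid .whl picked up by "ADD *.whl /" is already present in the build context:

    docker build -t registry.baidu.com/paddlepaddle/rawjob:vgg16_fluid .
    docker push registry.baidu.com/paddlepaddle/rawjob:vgg16_fluid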
@@ -0,0 +1,78 @@
#!/bin/env python
import os
import sys
import time
import socket
from kubernetes import client, config

PADDLE_JOB_NAME = os.getenv("PADDLE_JOB_NAME")
NAMESPACE = os.getenv("NAMESPACE")
PORT = os.getenv("PSERVER_PORT")
if os.getenv("KUBERNETES_SERVICE_HOST", None):
    config.load_incluster_config()
else:
    config.load_kube_config()
v1 = client.CoreV1Api()


def fetch_pods_info(label_selector):
    api_response = v1.list_namespaced_pod(
        namespace=NAMESPACE, pretty=True, label_selector=label_selector)
    pod_list = []
    for item in api_response.items:
        pod_list.append((item.status.phase, item.status.pod_ip))
    return pod_list


def wait_pods_running(label_selector, desired):
    print "label selector: %s, desired: %s" % (label_selector, desired)
    while True:
        count = count_pods_by_phase(label_selector, 'Running')
        # NOTE: pods may be scaled.
        if count >= int(desired):
            break
        print 'current cnt: %d sleep for 5 seconds...' % count
        time.sleep(5)


def count_pods_by_phase(label_selector, phase):
    pod_list = fetch_pods_info(label_selector)
    filtered_pod_list = filter(lambda x: x[0] == phase, pod_list)
    return len(filtered_pod_list)


def fetch_pserver_ips():
    label_selector = "paddle-job-pserver=%s" % PADDLE_JOB_NAME
    pod_list = fetch_pods_info(label_selector)
    pserver_ips = [item[1] for item in pod_list]
    return ",".join(pserver_ips)


def fetch_master_ip():
    label_selector = "paddle-job-master=%s" % PADDLE_JOB_NAME
    pod_list = fetch_pods_info(label_selector)
    master_ips = [item[1] for item in pod_list]
    return master_ips[0]


def fetch_trainer_id():
    label_selector = "paddle-job=%s" % PADDLE_JOB_NAME
    pod_list = fetch_pods_info(label_selector)
    trainer_ips = [item[1] for item in pod_list]
    trainer_ips.sort()
    local_ip = socket.gethostbyname(socket.gethostname())
    for i in xrange(len(trainer_ips)):
        if trainer_ips[i] == local_ip:
            return i
    return None


if __name__ == "__main__":
    command = sys.argv[1]
    if command == "fetch_pserver_ips":
        print fetch_pserver_ips()
    elif command == "fetch_trainer_id":
        print fetch_trainer_id()
    elif command == "fetch_master_ip":
        print fetch_master_ip()
    elif command == "count_pods_by_phase":
        print count_pods_by_phase(sys.argv[2], sys.argv[3])
    elif command == "wait_pods_running":
        wait_pods_running(sys.argv[2], sys.argv[3])
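k8s_tools.py is the helper that the paddle_k8s entrypoint below shells out to for service discovery. A usage sketch, assuming PADDLE_JOB_NAME and NAMESPACE are set in the environment the way the manifests further down set them (job name vgg16job, 10 pservers):

    python /root/k8s_tools.py wait_pods_running paddle-job-pserver=vgg16job 10
    python /root/k8s_tools.py fetch_pserver_ips
    python /root/k8s_tools.py fetch_trainer_id
    python /root/k8s_tools.py count_pods_by_phase paddle-job=vgg16job Failed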
@@ -0,0 +1,200 @@
#!/bin/bash
start_pserver() {
    stdbuf -oL paddle pserver \
        --use_gpu=0 \
        --port=$PADDLE_INIT_PORT \
        --ports_num=$PADDLE_INIT_PORTS_NUM \
        --ports_num_for_sparse=$PADDLE_INIT_PORTS_NUM_FOR_SPARSE \
        --nics=$PADDLE_INIT_NICS \
        --comment=paddle_process_k8s \
        --num_gradient_servers=$PADDLE_INIT_NUM_GRADIENT_SERVERS
}

start_new_pserver() {
    stdbuf -oL python /root/k8s_tools.py wait_pods_running paddle-job-master=${PADDLE_JOB_NAME} 1
    export MASTER_IP=$(python /root/k8s_tools.py fetch_master_ip)
    stdbuf -oL /usr/bin/pserver \
        -port=$PADDLE_INIT_PORT \
        -num-pservers=$PSERVERS \
        -log-level=debug \
        -etcd-endpoint=http://$MASTER_IP:2379
}

start_master() {
    stdbuf -oL /usr/bin/master \
        -port=8080 \
        -chunk-per-task=1 \
        -task-timout-dur=16s \
        -endpoints=http://127.0.0.1:2379
}

check_failed_cnt() {
    max_failed=$1
    failed_count=$(python /root/k8s_tools.py count_pods_by_phase paddle-job=${PADDLE_JOB_NAME} Failed)
    if [ $failed_count -gt $max_failed ]; then
        stdbuf -oL echo "Failed trainer count beyond the threshold: "$max_failed
        echo "Failed trainer count beyond the threshold: " $max_failed > /dev/termination-log
        exit 0
    fi
}

check_trainer_ret() {
    ret=$1
    stdbuf -oL echo "job returned $ret...setting pod return message..."
    stdbuf -oL echo "==============================="

    if [ $ret -eq 136 ] ; then
        echo "Error Arithmetic Operation(Floating Point Exception)" > /dev/termination-log
    elif [ $ret -eq 139 ] ; then
        echo "Segmentation Fault" > /dev/termination-log
    elif [ $ret -eq 1 ] ; then
        echo "General Error" > /dev/termination-log
    elif [ $ret -eq 134 ] ; then
        echo "Program Abort" > /dev/termination-log
    fi
    stdbuf -oL echo "termination log written..."
    exit $ret
}

start_fluid_process() {
    stdbuf -oL python /root/k8s_tools.py wait_pods_running paddle-job-pserver=${PADDLE_JOB_NAME} ${PSERVERS}
    if [ "${TRAINING_ROLE}" == "TRAINER" ]; then
        check_failed_cnt ${TRAINERS}
        sleep 5
        stdbuf -oL python /root/k8s_tools.py wait_pods_running paddle-job-master=${PADDLE_JOB_NAME} 1
        export PADDLE_INIT_TRAINER_ID=$(python /root/k8s_tools.py fetch_trainer_id)
    fi
    export PADDLE_INIT_PSERVERS=$(python /root/k8s_tools.py fetch_pserver_ips)
    stdbuf -oL sh -c "${ENTRY}"
    check_trainer_ret $?
}

start_new_trainer() {
    # FIXME(Yancey1989): use a command-line interface to configure the max failed count
    check_failed_cnt ${TRAINERS}
    stdbuf -oL python /root/k8s_tools.py wait_pods_running paddle-job-pserver=${PADDLE_JOB_NAME} ${PSERVERS}
    sleep 5
    stdbuf -oL python /root/k8s_tools.py wait_pods_running paddle-job-master=${PADDLE_JOB_NAME} 1
    export MASTER_IP=$(python /root/k8s_tools.py fetch_master_ip)
    export ETCD_IP="$MASTER_IP"

    # NOTE: $TRAINER_PACKAGE may be large, do not copy
    export PYTHONPATH=$TRAINER_PACKAGE:$PYTHONPATH
    cd $TRAINER_PACKAGE

    stdbuf -oL echo "Starting training job: " $TRAINER_PACKAGE, "num_gradient_servers:" \
        $PADDLE_INIT_NUM_GRADIENT_SERVERS, "version: " $1

    stdbuf -oL sh -c "${ENTRY}"
    check_trainer_ret $?
}

start_trainer() {
    # Paddle v1 and v2 distributed training does not allow any trainer to fail.
    check_failed_cnt 0
    stdbuf -oL python /root/k8s_tools.py wait_pods_running paddle-job-pserver=${PADDLE_JOB_NAME} ${PSERVERS}
    stdbuf -oL python /root/k8s_tools.py wait_pods_running paddle-job=${PADDLE_JOB_NAME} ${TRAINERS}

    export PADDLE_INIT_PSERVERS=$(python /root/k8s_tools.py fetch_pserver_ips)
    export PADDLE_INIT_TRAINER_ID=$(python /root/k8s_tools.py fetch_trainer_id)
    stdbuf -oL echo $PADDLE_INIT_TRAINER_ID > /trainer_id
    # FIXME: /trainer_count = PADDLE_INIT_NUM_GRADIENT_SERVERS
    stdbuf -oL echo $PADDLE_INIT_NUM_GRADIENT_SERVERS > /trainer_count

    # NOTE: $TRAINER_PACKAGE may be large, do not copy
    export PYTHONPATH=$TRAINER_PACKAGE:$PYTHONPATH
    cd $TRAINER_PACKAGE

    stdbuf -oL echo "Starting training job: " $TRAINER_PACKAGE, "num_gradient_servers:" \
        $PADDLE_INIT_NUM_GRADIENT_SERVERS, "trainer_id: " $PADDLE_INIT_TRAINER_ID, \
        "version: " $1

    # FIXME: If we use the new Golang PServer, add a Kubernetes healthz check
    # to wait for the PServer process to get ready. For now just sleep 20 seconds.
    sleep 20

    case "$1" in
        "v1")
            FILE_COUNT=$(wc -l $TRAIN_LIST | awk '{print $1}')
            if [ $FILE_COUNT -le $PADDLE_INIT_NUM_GRADIENT_SERVERS ]; then
                echo "file count less than trainers"
                check_trainer_ret 0
            fi
            let lines_per_node="$FILE_COUNT / ($PADDLE_INIT_NUM_GRADIENT_SERVERS + 1)"
            echo "splitting file into" $lines_per_node
            cp $TRAIN_LIST /
            cd /
            split -l $lines_per_node -d -a 3 $TRAIN_LIST train.list
            CURRENT_LIST=$(printf "train.list%03d" $PADDLE_INIT_TRAINER_ID)
            # always use /train.list for paddle v1 for each node.
            echo "File for current node ${CURRENT_LIST}"
            sleep 10
            cp $CURRENT_LIST train.list

            cd $TRAINER_PACKAGE

            stdbuf -oL paddle train \
                --port=$PADDLE_INIT_PORT \
                --nics=$PADDLE_INIT_NICS \
                --ports_num=$PADDLE_INIT_PORTS_NUM \
                --ports_num_for_sparse=$PADDLE_INIT_PORTS_NUM_FOR_SPARSE \
                --num_passes=$PADDLE_INIT_NUM_PASSES \
                --trainer_count=$PADDLE_INIT_TRAINER_COUNT \
                --saving_period=1 \
                --log_period=20 \
                --local=0 \
                --rdma_tcp=tcp \
                --config=$TOPOLOGY \
                --use_gpu=$PADDLE_INIT_USE_GPU \
                --trainer_id=$PADDLE_INIT_TRAINER_ID \
                --save_dir=$OUTPUT \
                --pservers=$PADDLE_INIT_PSERVERS \
                --num_gradient_servers=$PADDLE_INIT_NUM_GRADIENT_SERVERS
            # paddle v1 API does not allow any trainer to fail.
            check_trainer_ret $?
            ;;
        "v2")
            stdbuf -oL sh -c "${ENTRY}"
            # paddle v2 API does not allow any trainer to fail.
            check_trainer_ret $?
            ;;
        *)
            ;;
    esac
}

usage() {
    echo "usage: paddle_k8s [<args>]:"
    echo "  start_trainer [v1|v2]    Start a trainer process with the v1 or v2 API"
    echo "  start_pserver            Start a pserver process"
    echo "  start_new_pserver        Start a new pserver process"
    echo "  start_new_trainer        Start a new trainer process"
}

case "$1" in
    start_pserver)
        start_pserver
        ;;
    start_trainer)
        start_trainer $2
        ;;
    start_new_trainer)
        start_new_trainer
        ;;
    start_new_pserver)
        start_new_pserver
        ;;
    start_master)
        start_master
        ;;
    start_fluid)
        start_fluid_process
        ;;
    --help)
        usage
        ;;
    *)
        usage
        ;;
esac
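In the manifests below this entrypoint is launched as "paddle_k8s start_fluid" on the pservers and "paddle_k8s start_trainer v2" on the trainers, with all configuration passed through environment variables. A hedged sketch of an equivalent manual invocation; the values mirror the pserver manifest, and NAMESPACE is shown as "default" only as an example since the pods normally get it from the downward API:

    export PADDLE_JOB_NAME=vgg16job NAMESPACE=default PSERVERS=10 TRAINERS=20
    export TRAINING_ROLE=PSERVER
    export ENTRY="MKL_NUM_THREADS=1 python /workspace/vgg16.py --local 0"
    paddle_k8s start_fluid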
@@ -0,0 +1,72 @@
apiVersion: extensions/v1beta1
kind: ReplicaSet
metadata:
  name: vgg16job-pserver
spec:
  replicas: 10
  template:
    metadata:
      labels:
        paddle-job-pserver: vgg16job
    spec:
      hostNetwork: true
      imagePullSecrets:
      - name: job-registry-secret
      containers:
      - name: pserver
        image: "registry.baidu.com/paddlepaddle/rawjob:vgg16_fluid"
        imagePullPolicy: Always
        ports:
        - name: jobport-30236
          containerPort: 30236
        env:
        - name: PADDLE_JOB_NAME
          value: vgg16job
        - name: MKL_NUM_THREADS
          value: "1"
        - name: TRAINING_ROLE
          value: "PSERVER"
        - name: TRAINERS
          value: "20"
        - name: PSERVERS
          value: "10"
        - name: TOPOLOGY
          value: ""
        - name: ENTRY
          value: "MKL_NUM_THREADS=1 python /workspace/vgg16.py --local 0"
        - name: TRAINER_PACKAGE
          value: "/workspace"
        - name: PADDLE_INIT_PORT
          value: "30236"
        - name: PADDLE_INIT_NICS
          value: "xgbe0"
        - name: PADDLE_INIT_TRAINER_COUNT
          value: "1"
        - name: PADDLE_INIT_PORTS_NUM
          value: "1"
        - name: PADDLE_INIT_PORTS_NUM_FOR_SPARSE
          value: "1"
        - name: PADDLE_INIT_NUM_GRADIENT_SERVERS
          value: "20"
        - name: PADDLE_INIT_NUM_PASSES
          value: "1"
        - name: PADDLE_INIT_USE_GPU
          value: "0"
        - name: LD_LIBRARY_PATH
          value: "/usr/local/nvidia/lib64"
        - name: NAMESPACE
          valueFrom:
            fieldRef:
              fieldPath: "metadata.namespace"
        - name: POD_IP
          valueFrom:
            fieldRef:
              fieldPath: "status.podIP"
        command: ["paddle_k8s", "start_fluid"]
        resources:
          requests:
            memory: 10Gi
            cpu: 4
          limits:
            memory: 10Gi
            cpu: 4
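This ReplicaSet brings up the 10 parameter-server pods. A deployment sketch, assuming the manifest is saved as pserver.yaml (file name assumed) and kubectl is configured for the target cluster:

    kubectl create -f pserver.yaml
    kubectl get pods -l paddle-job-pserver=vgg16job -o wide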
@@ -0,0 +1,2 @@
import paddle.v2 as paddle
paddle.dataset.cifar.train10()
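reader.py just forces a download of the CIFAR-10 training set; the commented-out lines at the top of the Dockerfile show how it was previously used to pre-fetch the data into the image:

    python reader.py    # caches the dataset locally, typically under ~/.cache/paddle/dataset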
@@ -0,0 +1,69 @@
apiVersion: batch/v1
kind: Job
metadata:
  name: vgg16job-trainer
spec:
  parallelism: 20
  completions: 20
  template:
    metadata:
      labels:
        paddle-job: vgg16job
    spec:
      imagePullSecrets:
      - name: job-registry-secret
      hostNetwork: true
      containers:
      - name: trainer
        image: "registry.baidu.com/paddlepaddle/rawjob:vgg16_fluid"
        imagePullPolicy: Always
        command: ["paddle_k8s", "start_trainer", "v2"]
        env:
        - name: PADDLE_JOB_NAME
          value: vgg16job
        - name: TRAINING_ROLE
          value: "TRAINER"
        - name: TRAINERS
          value: "20"
        - name: PSERVERS
          value: "10"
        - name: TOPOLOGY
          value: ""
        - name: ENTRY
          value: "cd /workspace && MKL_NUM_THREADS=1 python /workspace/vgg16.py"
        - name: TRAINER_PACKAGE
          value: "/workspace"
        - name: PADDLE_INIT_PORT
          value: "30236"
        - name: PADDLE_INIT_NICS
          value: "xgbe0"
        - name: PADDLE_INIT_TRAINER_COUNT
          value: "1"
        - name: PADDLE_INIT_PORTS_NUM
          value: "1"
        - name: PADDLE_INIT_PORTS_NUM_FOR_SPARSE
          value: "1"
        - name: PADDLE_INIT_NUM_GRADIENT_SERVERS
          value: "20"
        - name: PADDLE_INIT_NUM_PASSES
          value: "1"
        - name: PADDLE_INIT_USE_GPU
          value: "0"
        - name: LD_LIBRARY_PATH
          value: "/usr/local/nvidia/lib64"
        - name: NAMESPACE
          valueFrom:
            fieldRef:
              fieldPath: "metadata.namespace"
        - name: POD_IP
          valueFrom:
            fieldRef:
              fieldPath: "status.podIP"
        resources:
          requests:
            memory: 40Gi
            cpu: 2
          limits:
            memory: 40Gi
            cpu: 2
      restartPolicy: Never
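The trainer side is a batch Job with parallelism and completions both set to 20, matching TRAINERS. A submission sketch, assuming the manifest is saved as trainer.yaml (file name assumed):

    kubectl create -f trainer.yaml
    kubectl get pods -l paddle-job=vgg16job
    kubectl logs -f <one-of-the-trainer-pods>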
@@ -0,0 +1,248 @@
"""VGG16 benchmark in Fluid"""
from __future__ import print_function

import sys
import time
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core
import argparse
import functools
import os


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    '--batch_size', type=int, default=128, help="Batch size for training.")
parser.add_argument(
    '--learning_rate',
    type=float,
    default=1e-3,
    help="Learning rate for training.")
parser.add_argument('--num_passes', type=int, default=50, help="No. of passes.")
parser.add_argument(
    '--device',
    type=str,
    default='CPU',
    choices=['CPU', 'GPU'],
    help="The device type.")
parser.add_argument(
    '--data_format',
    type=str,
    default='NCHW',
    choices=['NCHW', 'NHWC'],
    help='The data order, now only support NCHW.')
parser.add_argument(
    '--data_set',
    type=str,
    default='cifar10',
    choices=['cifar10', 'flowers'],
    help='Optional dataset for benchmark.')
parser.add_argument(
    '--local',
    type=str2bool,
    default=True,
    help='Whether to run as local mode.')
args = parser.parse_args()


def vgg16_bn_drop(input):
    def conv_block(input, num_filter, groups, dropouts):
        return fluid.nets.img_conv_group(
            input=input,
            pool_size=2,
            pool_stride=2,
            conv_num_filter=[num_filter] * groups,
            conv_filter_size=3,
            conv_act='relu',
            conv_with_batchnorm=True,
            conv_batchnorm_drop_rate=dropouts,
            pool_type='max')

    conv1 = conv_block(input, 64, 2, [0.3, 0])
    conv2 = conv_block(conv1, 128, 2, [0.4, 0])
    conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0])
    conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0])
    conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0])

    drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
    fc1 = fluid.layers.fc(input=drop, size=512, act=None)
    bn = fluid.layers.batch_norm(input=fc1, act='relu')
    drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
    fc2 = fluid.layers.fc(input=drop2, size=512, act=None)
    return fc2


def main():
    if args.data_set == "cifar10":
        classdim = 10
        if args.data_format == 'NCHW':
            data_shape = [3, 32, 32]
        else:
            data_shape = [32, 32, 3]
    else:
        classdim = 102
        if args.data_format == 'NCHW':
            data_shape = [3, 224, 224]
        else:
            data_shape = [224, 224, 3]

    # Input data
    images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # Train program
    net = vgg16_bn_drop(images)
    predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
    cost = fluid.layers.cross_entropy(input=predict, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # Evaluator
    accuracy = fluid.evaluator.Accuracy(input=predict, label=label)

    # inference program
    inference_program = fluid.default_main_program().clone()
    with fluid.program_guard(inference_program):
        test_target = accuracy.metrics + accuracy.states
        inference_program = fluid.io.get_inference_program(test_target)

    # Optimization
    optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate)
    optimize_ops, params_grads = optimizer.minimize(avg_cost)

    # Initialize executor
    place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
    exe = fluid.Executor(place)

    # test
    def test(exe):
        accuracy.reset(exe)
        for batch_id, data in enumerate(test_reader()):
            img_data = np.array(map(lambda x: x[0].reshape(data_shape),
                                    data)).astype("float32")
            y_data = np.array(map(lambda x: x[1], data)).astype("int64")
            y_data = y_data.reshape([-1, 1])

            exe.run(inference_program,
                    feed={"pixel": img_data,
                          "label": y_data})

        return accuracy.eval(exe)

    def train_loop(exe, trainer_prog):
        iters = 0
        for pass_id in range(args.num_passes):
            # train
            start_time = time.time()
            num_samples = 0
            accuracy.reset(exe)
            for batch_id, data in enumerate(train_reader()):
                img_data = np.array(map(lambda x: x[0].reshape(data_shape),
                                        data)).astype("float32")
                y_data = np.array(map(lambda x: x[1], data)).astype("int64")
                y_data = y_data.reshape([-1, 1])

                loss, acc = exe.run(trainer_prog,
                                    feed={"pixel": img_data,
                                          "label": y_data},
                                    fetch_list=[avg_cost] + accuracy.metrics)
                iters += 1
                num_samples += len(data)
                print(
                    "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f" %
                    (pass_id, iters, loss, acc)
                )  # The accuracy is accumulated over batches, not just the current batch.

            pass_elapsed = time.time() - start_time
            pass_train_acc = accuracy.eval(exe)
            pass_test_acc = test(exe)
            print(
                "Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f\n"
                % (pass_id, num_samples / pass_elapsed, pass_train_acc,
                   pass_test_acc))

    if args.local:
        # Parameter initialization
        exe.run(fluid.default_startup_program())

        # data reader
        train_reader = paddle.batch(
            paddle.reader.shuffle(
                paddle.dataset.cifar.train10()
                if args.data_set == 'cifar10' else paddle.dataset.flowers.train(),
                buf_size=5120),
            batch_size=args.batch_size)
        test_reader = paddle.batch(
            paddle.dataset.cifar.test10()
            if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
            batch_size=args.batch_size)
        train_loop(exe, fluid.default_main_program())
    else:
        pserver_ips = os.getenv("PADDLE_INIT_PSERVERS")  # all pserver endpoints
        eplist = []
        for ip in pserver_ips.split(","):
            eplist.append(':'.join([ip, "6174"]))
        pserver_endpoints = ",".join(eplist)
        print("pserver endpoints: ", pserver_endpoints)
        trainers = int(os.getenv("TRAINERS"))  # total trainer count
        current_endpoint = os.getenv("POD_IP") + ":6174"  # current pserver endpoint
        training_role = os.getenv("TRAINING_ROLE",
                                  "TRAINER")  # get the training role: trainer/pserver
        t = fluid.DistributeTranspiler()
        t.transpile(
            optimize_ops, params_grads, pservers=pserver_endpoints, trainers=trainers)

        if training_role == "PSERVER":
            if not current_endpoint:
                print("need env SERVER_ENDPOINT")
                exit(1)
            pserver_prog = t.get_pserver_program(current_endpoint)
            pserver_startup = t.get_startup_program(current_endpoint, pserver_prog)
            print("starting server side startup")
            exe.run(pserver_startup)
            print("starting parameter server...")
            exe.run(pserver_prog)
        elif training_role == "TRAINER":
            # Parameter initialization
            exe.run(fluid.default_startup_program())

            # data reader
            train_reader = paddle.batch(
                paddle.reader.shuffle(
                    paddle.dataset.cifar.train10()
                    if args.data_set == 'cifar10' else paddle.dataset.flowers.train(),
                    buf_size=5120),
                batch_size=args.batch_size)
            test_reader = paddle.batch(
                paddle.dataset.cifar.test10()
                if args.data_set == 'cifar10' else paddle.dataset.flowers.test(),
                batch_size=args.batch_size)

            trainer_prog = t.get_trainer_program()
            feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
            # TODO(typhoonzero): change trainer startup program to fetch parameters from pserver
            exe.run(fluid.default_startup_program())
            train_loop(exe, trainer_prog)
        else:
            print("environment variable TRAINING_ROLE should be TRAINER or PSERVER")


def print_arguments():
    print('----------- Configuration Arguments -----------')
    for arg, value in sorted(vars(args).iteritems()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


if __name__ == "__main__":
    print_arguments()
    main()
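A local smoke test of the benchmark script, using only flags defined by the argparse options above (CPU and cifar10 are already the defaults):

    MKL_NUM_THREADS=1 python vgg16.py --local 1 --device CPU --batch_size 128 --num_passes 1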