Merge pull request #2245 from helinwang/master_server
implement task handling for master server's service
commit
7af02682f7
@@ -0,0 +1,93 @@
package main

import (
	"fmt"
	"net"
	"net/http"
	"net/rpc"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/namsral/flag"

	"github.com/PaddlePaddle/Paddle/paddle/go/master"
	"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
)

func main() {
	port := flag.Int("port", 8080, "port of the master server.")
	dataset := flag.String("training_dataset", "", "dataset: comma separated paths to RecordIO files, supports glob patterns.")
	faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
	taskTimeoutDur := flag.Duration("task_timeout_dur", 20*time.Minute, "task timeout duration.")
	taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timeout count for each task before it is declared a failed task.")
	chunkPerTask := flag.Int("chunk_per_task", 10, "chunks per task.")
	flag.Parse()

	if *dataset == "" {
		panic("no dataset specified.")
	}

	if *faultTolerance {
		panic("fault tolerance not implemented.")
	}

	// Expand the comma separated dataset spec into concrete file paths.
	var chunks []master.Chunk
	var paths []string
	ss := strings.Split(*dataset, ",")
	fmt.Println(ss)
	for _, s := range ss {
		match, err := filepath.Glob(s)
		if err != nil {
			panic(err)
		}
		paths = append(paths, match...)
	}

	if len(paths) == 0 {
		panic("no valid dataset specified.")
	}

	// Load the RecordIO index of every file and collect all chunks.
	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
			panic(err)
		}

		index, err := recordio.LoadIndex(f)
		if err != nil {
			panic(err)
		}
		f.Close()

		count := index.NumChunks()
		for i := 0; i < count; i++ {
			chunk := master.Chunk{
				Idx:   i, // index of the chunk within the file
				Path:  path,
				Index: *index.ChunkIndex(i),
			}
			chunks = append(chunks, chunk)
		}
	}

	// Register the master service with net/rpc and serve it over HTTP.
	s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
	err := rpc.Register(s)
	if err != nil {
		panic(err)
	}

	rpc.HandleHTTP()
	l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
	if err != nil {
		panic(err)
	}

	err = http.Serve(l, nil)
	if err != nil {
		panic(err)
	}
}
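For reference, a hypothetical invocation of this server (the binary name `master` is assumed; the flags are the ones defined above, and github.com/namsral/flag also accepts them as environment variables such as TRAINING_DATASET):

    master -port 8080 -training_dataset "/data/part-*.recordio" -chunk_per_task 10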
@@ -0,0 +1,178 @@
package master

import (
	"errors"
	"log"
	"sync"
	"time"

	"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
)

const (
	targetTaskCount = 300
)

// errors
var (
	ErrNoMoreTask          = errors.New("no more task for current pass")
	ErrPendingTaskNotFound = errors.New("pending task not found")
)

// Service is the master server service.
type Service struct {
	timeoutDur time.Duration
	timeoutMax int

	mu         sync.Mutex
	taskQueues taskQueues
}

// Recover recovers service state from etcd.
func Recover() (*Service, error) {
	// TODO(helin): recover from snapshot state from etcd.
	return nil, nil
}

// partition groups chunks into tasks of chunksPerTask chunks each (the
// last task may hold fewer), assigning consecutive task IDs from 0.
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
	id := 0
	if chunksPerTask <= 0 {
		chunksPerTask = 1
	}

	var result []taskEntry
	var cur taskEntry
	for i, c := range chunks {
		if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
			cur.Task.ID = id
			id++
			result = append(result, cur)
			cur.Task.Chunks = nil
		}

		cur.Task.Chunks = append(cur.Task.Chunks, c)
	}

	if len(cur.Task.Chunks) > 0 {
		cur.Task.ID = id
		id++
		result = append(result, cur)
	}

	return result
}

// NewService creates a new service.
func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
	s := &Service{}
	s.timeoutDur = timeoutDur
	s.timeoutMax = timeoutMax
	s.taskQueues = taskQueues{}
	s.taskQueues.Pending = make(map[int]taskEntry)
	s.taskQueues.Todo = partition(chunks, chunksPerTask)
	return s
}

// Chunk is a chunk of data consisting of several data instances.
type Chunk struct {
	Idx   int // index of the chunk within the file
	Path  string
	Index recordio.Index // block index
}

// Task is the basic unit of data instances assigned to trainers.
type Task struct {
	ID     int
	Chunks []Chunk
}

type taskEntry struct {
	Epoch      int
	NumTimeout int
	Task       Task
}

type taskQueues struct {
	Todo    []taskEntry
	Pending map[int]taskEntry // map from task ID to task entry
	Done    []taskEntry
	Failed  []Task
}

// snapshot *must* be called with s.mu being held.
func (s *Service) snapshot() error {
	// TODO(helin): snapshot state on etcd.
	return nil
}

// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.taskQueues.Todo) == 0 {
		return ErrNoMoreTask
	}

	// Move the task from Todo to Pending, bumping its epoch so that
	// stale timeout checks from earlier dispatches can be detected.
	t := s.taskQueues.Todo[0]
	t.Epoch++
	s.taskQueues.Todo = s.taskQueues.Todo[1:]
	s.taskQueues.Pending[t.Task.ID] = t
	err := s.snapshot()
	if err != nil {
		return err
	}

	*task = t.Task

	// Schedule a timeout check; the closure captures the task ID and
	// epoch as they were at dispatch time.
	time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
		return func() {
			s.mu.Lock()
			defer s.mu.Unlock()

			t, ok := s.taskQueues.Pending[taskID]
			if !ok {
				return
			}

			if t.Epoch != epoch {
				// new epoch, task launched after the
				// schedule of this timeout check.
				return
			}

			defer func() {
				err := s.snapshot()
				if err != nil {
					log.Println(err)
				}
			}()

			delete(s.taskQueues.Pending, t.Task.ID)

			t.NumTimeout++
			if t.NumTimeout > s.timeoutMax {
				s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
				return
			}

			s.taskQueues.Todo = append(s.taskQueues.Todo, t)
		}
	}(t.Task.ID, t.Epoch))
	return nil
}

// TaskFinished tells the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	t, ok := s.taskQueues.Pending[taskID]
	if !ok {
		return ErrPendingTaskNotFound
	}

	// task finished, reset timeout
	t.NumTimeout = 0
	s.taskQueues.Done = append(s.taskQueues.Done, t)
	delete(s.taskQueues.Pending, taskID)
	return s.snapshot()
}
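To make the RPC surface above concrete, here is a minimal sketch of a trainer-side client; it is not part of this PR, the address is illustrative, and note that over net/rpc the service's errors come back as plain error strings rather than the sentinel values defined above:

package main

import (
	"fmt"
	"net/rpc"

	"github.com/PaddlePaddle/Paddle/paddle/go/master"
)

func main() {
	// Dial the HTTP transport that rpc.HandleHTTP sets up on the master.
	c, err := rpc.DialHTTP("tcp", "localhost:8080")
	if err != nil {
		panic(err)
	}

	for {
		var t master.Task
		// The int argument is a dummy; the service ignores it.
		if err := c.Call("Service.GetTask", 0, &t); err != nil {
			// ErrNoMoreTask is returned once the pass is exhausted.
			fmt.Println("no task:", err)
			break
		}

		// ... train on t.Chunks here ...

		var dummy int
		if err := c.Call("Service.TaskFinished", t.ID, &dummy); err != nil {
			fmt.Println("TaskFinished:", err)
		}
	}
}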
@@ -0,0 +1,37 @@
package master

import "testing"

func TestPartitionCount(t *testing.T) {
	cs := make([]Chunk, 100)
	ts := partition(cs, 5)
	if len(ts) != 20 {
		t.Error(len(ts))
	}

	cs = make([]Chunk, 101)
	ts = partition(cs, 5)
	if len(ts) != 21 {
		t.Error(len(ts))
	}

	ts = partition(cs, 1)
	if len(ts) != 101 {
		t.Error(len(ts))
	}

	ts = partition(cs, 0)
	if len(ts) != 101 {
		t.Error(len(ts))
	}
}

func TestPartitionIndex(t *testing.T) {
	cs := make([]Chunk, 100)
	ts := partition(cs, 20)
	for i := range ts {
		if ts[i].Task.ID != i {
			t.Error(ts[i], i)
		}
	}
}
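As a usage note (the package path is inferred from the import paths above), these tests should run with the standard tooling, e.g. go test github.com/PaddlePaddle/Paddle/paddle/go/master. For intuition on the counts being asserted: partition(make([]Chunk, 101), 5) yields 21 tasks, twenty of five chunks plus a final task holding the single leftover chunk, with IDs 0 through 20.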