|
|
|
@ -2,12 +2,13 @@ package master
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"errors"
|
|
|
|
|
"log"
|
|
|
|
|
"os"
|
|
|
|
|
"path/filepath"
|
|
|
|
|
"sync"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
|
|
|
|
|
|
"github.com/PaddlePaddle/recordio"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -112,7 +113,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(paths) == 0 {
|
|
|
|
|
return nil, errors.New("no valid datset specified")
|
|
|
|
|
return nil, errors.New("no valid dataset specified")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, path := range paths {
|
|
|
|
@ -170,6 +171,7 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
|
|
|
|
|
|
|
|
|
|
err = s.snapshot()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -178,6 +180,43 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
|
|
|
|
|
return func() {
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
t, ok := s.taskQueues.Pending[taskID]
|
|
|
|
|
if !ok {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if t.Epoch != epoch {
|
|
|
|
|
// new epoch, task launched after the
|
|
|
|
|
// schedule of this timeout check.
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
|
err := s.snapshot()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
delete(s.taskQueues.Pending, t.Task.ID)
|
|
|
|
|
|
|
|
|
|
t.NumTimeout++
|
|
|
|
|
if t.NumTimeout > s.timeoutMax {
|
|
|
|
|
log.Warningf("Task %v failed %d times, discard.\n", t.Task, t.NumTimeout)
|
|
|
|
|
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
log.Warningf("Task %v failed %d times, retry.\n", t.Task, t.NumTimeout)
|
|
|
|
|
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetTask gets a new task from the service.
|
|
|
|
|
func (s *Service) GetTask(dummy int, task *Task) error {
|
|
|
|
|
select {
|
|
|
|
@ -190,19 +229,25 @@ func (s *Service) GetTask(dummy int, task *Task) error {
|
|
|
|
|
if len(s.taskQueues.Todo) == 0 {
|
|
|
|
|
if len(s.taskQueues.Done) == 0 {
|
|
|
|
|
if len(s.taskQueues.Pending) == 0 {
|
|
|
|
|
return errors.New("all task failed")
|
|
|
|
|
err := errors.New("all task failed")
|
|
|
|
|
log.Warningln(err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(helin): client need to retry in this
|
|
|
|
|
// error case. Gotcha: RPC client can't
|
|
|
|
|
// compare returned error with predefined
|
|
|
|
|
// errors like io.EOF. Because interface don't
|
|
|
|
|
// errors like io.EOF, because interface don't
|
|
|
|
|
// have same dynamic value when in different
|
|
|
|
|
// process.
|
|
|
|
|
return errors.New("no more available task")
|
|
|
|
|
// process. So we need to figure out a way for
|
|
|
|
|
// client to check this error correctly.
|
|
|
|
|
err := errors.New("no more available task")
|
|
|
|
|
log.Warningln(err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
s.taskQueues.Todo = s.taskQueues.Done
|
|
|
|
|
s.taskQueues.Todo = nil
|
|
|
|
|
s.taskQueues.Done = nil
|
|
|
|
|
log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
t := s.taskQueues.Todo[0]
|
|
|
|
@ -215,41 +260,9 @@ func (s *Service) GetTask(dummy int, task *Task) error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*task = t.Task
|
|
|
|
|
log.Infof("Task #%d dispatched\n", task.ID)
|
|
|
|
|
|
|
|
|
|
time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
|
|
|
|
|
return func() {
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
t, ok := s.taskQueues.Pending[taskID]
|
|
|
|
|
if !ok {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if t.Epoch != epoch {
|
|
|
|
|
// new epoch, task launched after the
|
|
|
|
|
// schedule of this timeout check.
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
|
err := s.snapshot()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Println(err)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
delete(s.taskQueues.Pending, t.Task.ID)
|
|
|
|
|
|
|
|
|
|
t.NumTimeout++
|
|
|
|
|
if t.NumTimeout > s.timeoutMax {
|
|
|
|
|
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
|
|
|
|
|
}
|
|
|
|
|
}(t.Task.ID, t.Epoch))
|
|
|
|
|
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -262,9 +275,13 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
log.Infof("Task %d finished\n", taskID)
|
|
|
|
|
|
|
|
|
|
t, ok := s.taskQueues.Pending[taskID]
|
|
|
|
|
if !ok {
|
|
|
|
|
return errors.New("pending task not found")
|
|
|
|
|
err := errors.New("pending task not found")
|
|
|
|
|
log.Warningln(err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// task finished, reset timeout
|
|
|
|
@ -272,10 +289,15 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
|
|
|
|
|
s.taskQueues.Done = append(s.taskQueues.Done, t)
|
|
|
|
|
delete(s.taskQueues.Pending, taskID)
|
|
|
|
|
|
|
|
|
|
if len(s.taskQueues.Pending) == 0 {
|
|
|
|
|
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
|
|
|
|
|
log.Infoln("No more todo and pending task, start a new pass.")
|
|
|
|
|
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
|
|
|
|
|
s.taskQueues.Done = nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return s.snapshot()
|
|
|
|
|
err := s.snapshot()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|