|
|
|
@ -31,10 +31,15 @@ type Chunk struct {
|
|
|
|
|
Index recordio.Index // chunk index
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TaskMeta is a struct which stores task's meta info.
|
|
|
|
|
type TaskMeta struct {
|
|
|
|
|
ID int
|
|
|
|
|
Epoch int
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Task is the basic unit of data instances assigned to trainers.
|
|
|
|
|
type Task struct {
|
|
|
|
|
ID int
|
|
|
|
|
Epoch int
|
|
|
|
|
Meta TaskMeta
|
|
|
|
|
Chunks []Chunk
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -74,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
|
|
|
|
|
var cur taskEntry
|
|
|
|
|
for i, c := range chunks {
|
|
|
|
|
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
|
|
|
|
|
cur.Task.ID = id
|
|
|
|
|
cur.Task.Meta.ID = id
|
|
|
|
|
id++
|
|
|
|
|
result = append(result, cur)
|
|
|
|
|
cur.Task.Chunks = nil
|
|
|
|
@ -84,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(cur.Task.Chunks) > 0 {
|
|
|
|
|
cur.Task.ID = id
|
|
|
|
|
cur.Task.Meta.ID = id
|
|
|
|
|
result = append(result, cur)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -258,8 +263,8 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *Service) procFailedTask(t taskEntry, epoch int) {
|
|
|
|
|
if t.Task.Epoch != epoch {
|
|
|
|
|
func (s *Service) processFailedTask(t taskEntry, epoch int) {
|
|
|
|
|
if t.Task.Meta.Epoch != epoch {
|
|
|
|
|
// new epoch, task launched after the
|
|
|
|
|
// schedule of this timeout check or failed status report.
|
|
|
|
|
return
|
|
|
|
@ -272,7 +277,7 @@ func (s *Service) procFailedTask(t taskEntry, epoch int) {
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
delete(s.taskQueues.Pending, t.Task.ID)
|
|
|
|
|
delete(s.taskQueues.Pending, t.Task.Meta.ID)
|
|
|
|
|
|
|
|
|
|
t.NumFailure++
|
|
|
|
|
if t.NumFailure > s.failureMax {
|
|
|
|
@ -296,7 +301,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.procFailedTask(t, epoch)
|
|
|
|
|
s.processFailedTask(t, epoch)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -345,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
t := s.taskQueues.Todo[0]
|
|
|
|
|
t.Task.Epoch++
|
|
|
|
|
t.Task.Meta.Epoch++
|
|
|
|
|
s.taskQueues.Todo = s.taskQueues.Todo[1:]
|
|
|
|
|
s.taskQueues.Pending[t.Task.ID] = t
|
|
|
|
|
s.taskQueues.Pending[t.Task.Meta.ID] = t
|
|
|
|
|
err := s.snapshot()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*task = t.Task
|
|
|
|
|
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t)
|
|
|
|
|
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Meta)
|
|
|
|
|
|
|
|
|
|
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Task.Epoch))
|
|
|
|
|
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -373,7 +378,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
|
|
|
|
|
if !ok {
|
|
|
|
|
err := errors.New("pending task not found")
|
|
|
|
|
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
|
|
|
|
|
return err
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// task finished, reset timeout
|
|
|
|
@ -396,14 +401,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TaskID is a struct which client uses for reports failure.
|
|
|
|
|
type TaskID struct {
|
|
|
|
|
ID int
|
|
|
|
|
Epoch int
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TaskFailed tells the service that a task is failed.
|
|
|
|
|
func (s *Service) TaskFailed(taskID TaskID, dummy *int) error {
|
|
|
|
|
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
|
|
|
|
|
select {
|
|
|
|
|
case <-s.ready:
|
|
|
|
|
}
|
|
|
|
@ -411,13 +410,13 @@ func (s *Service) TaskFailed(taskID TaskID, dummy *int) error {
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
t, ok := s.taskQueues.Pending[taskID.ID]
|
|
|
|
|
t, ok := s.taskQueues.Pending[meta.ID]
|
|
|
|
|
if !ok {
|
|
|
|
|
err := errors.New("pending task not found")
|
|
|
|
|
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", taskID)
|
|
|
|
|
return err
|
|
|
|
|
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Meta)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.procFailedTask(t, taskID.Epoch)
|
|
|
|
|
s.processFailedTask(t, meta.Epoch)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|