|
|
|
@ -11,10 +11,6 @@ import (
|
|
|
|
|
"github.com/PaddlePaddle/recordio"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
targetTaskCount = 300
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// Service is the master server service.
|
|
|
|
|
type Service struct {
|
|
|
|
|
chunksPerTask int
|
|
|
|
@ -23,7 +19,7 @@ type Service struct {
|
|
|
|
|
ready chan struct{}
|
|
|
|
|
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
initBegan bool
|
|
|
|
|
initDone bool
|
|
|
|
|
taskQueues taskQueues
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -104,54 +100,35 @@ func (s *Service) snapshot() error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// SetDataset sets dataset to dispatch for the master server.
|
|
|
|
|
//
|
|
|
|
|
// SetDataset can be called multiple times. But only the first call will
|
|
|
|
|
// be honored.
|
|
|
|
|
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
|
|
|
|
|
if len(globPaths) == 0 {
|
|
|
|
|
return errors.New("no dataset specified")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
if s.initBegan {
|
|
|
|
|
// SetDataset already called. All trainers will call
|
|
|
|
|
// SetDataset, but we only handle the first one. Treat
|
|
|
|
|
// other calls as successful but do nothing.
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.initBegan = true
|
|
|
|
|
|
|
|
|
|
func getChunks(globPaths []string) ([]Chunk, error) {
|
|
|
|
|
var chunks []Chunk
|
|
|
|
|
var paths []string
|
|
|
|
|
|
|
|
|
|
for _, s := range globPaths {
|
|
|
|
|
match, err := filepath.Glob(s)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
paths = append(paths, match...)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(paths) == 0 {
|
|
|
|
|
return errors.New("no valid datset specified")
|
|
|
|
|
return nil, errors.New("no valid datset specified")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, path := range paths {
|
|
|
|
|
f, err := os.Open(path)
|
|
|
|
|
if err != nil {
|
|
|
|
|
panic(err)
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
index, err := recordio.LoadIndex(f)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
err = f.Close()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
count := index.NumChunks()
|
|
|
|
@ -164,14 +141,41 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return chunks, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// SetDataset sets dataset to dispatch for the master server.
|
|
|
|
|
//
|
|
|
|
|
// SetDataset can be called multiple times. But only the first call will
|
|
|
|
|
// be honored.
|
|
|
|
|
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
|
|
|
|
|
if len(globPaths) == 0 {
|
|
|
|
|
return errors.New("no dataset specified")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
if s.initDone {
|
|
|
|
|
// Already initialized. All trainers will call
|
|
|
|
|
// SetDataset, but we only handle the first one. Treat
|
|
|
|
|
// other calls as successful but do nothing.
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
chunks, err := getChunks(globPaths)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
|
|
|
|
|
|
|
|
|
|
err := s.snapshot()
|
|
|
|
|
err = s.snapshot()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
close(s.ready)
|
|
|
|
|
s.initDone = true
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -193,7 +197,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
|
|
|
|
|
// TODO(helin): clients need to retry in this
|
|
|
|
|
// error case. Gotcha: RPC client can't
|
|
|
|
|
// compare returned error with predefined
|
|
|
|
|
// erros like io.EOF. Because interface don't
|
|
|
|
|
// errors like io.EOF. Because interface don't
|
|
|
|
|
// have same dynamic value when in different
|
|
|
|
|
// process.
|
|
|
|
|
return errors.New("no more available task")
|
|
|
|
|