You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
94 lines
1.9 KiB
94 lines
1.9 KiB
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/rpc"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/namsral/flag"
|
|
|
|
"github.com/PaddlePaddle/Paddle/go/master"
|
|
"github.com/PaddlePaddle/Paddle/go/recordio"
|
|
)
|
|
|
|
func main() {
|
|
port := flag.Int("port", 8080, "port of the master server.")
|
|
dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.")
|
|
faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
|
|
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
|
|
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
|
|
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
|
|
flag.Parse()
|
|
|
|
if *dataset == "" {
|
|
panic("no dataset specified.")
|
|
}
|
|
|
|
if *faultTolerance {
|
|
panic("fault tolernance not implemented.")
|
|
}
|
|
|
|
var chunks []master.Chunk
|
|
var paths []string
|
|
ss := strings.Split(*dataset, ",")
|
|
fmt.Println(ss)
|
|
for _, s := range ss {
|
|
match, err := filepath.Glob(s)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
paths = append(paths, match...)
|
|
}
|
|
|
|
if len(paths) == 0 {
|
|
panic("no valid datset specified.")
|
|
}
|
|
|
|
idx := 0
|
|
for _, path := range paths {
|
|
f, err := os.Open(path)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
index, err := recordio.LoadIndex(f)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
f.Close()
|
|
|
|
count := index.NumChunks()
|
|
for i := 0; i < count; i++ {
|
|
chunk := master.Chunk{
|
|
Idx: idx,
|
|
Path: path,
|
|
Index: *index.ChunkIndex(i),
|
|
}
|
|
chunks = append(chunks, chunk)
|
|
}
|
|
}
|
|
|
|
s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
|
|
err := rpc.Register(s)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
rpc.HandleHTTP()
|
|
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
err = http.Serve(l, nil)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}
|