Merge pull request #2429 from helinwang/master_client
implement master server client, remove unnecessary dummy variable
gangliao-patch-1
commit
5f5e128d29
@ -0,0 +1,82 @@
|
||||
package master
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/connection"
|
||||
)
|
||||
|
||||
// Addresser provides the address of the master server.
//
// Address is expected to return the current address each time it is
// called; an empty string is treated by the client as "no master".
type Addresser interface {
	// Address returns the master server address.
	Address() string
}
|
||||
|
||||
// Client is the client of the master server.
|
||||
type Client struct {
|
||||
conn *connection.Conn
|
||||
}
|
||||
|
||||
// NewClient creates a new Client.
|
||||
func NewClient(addr Addresser) *Client {
|
||||
c := &Client{}
|
||||
c.conn = connection.New()
|
||||
go c.monitorMaster(addr)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) monitorMaster(addr Addresser) {
|
||||
lastMaster := ""
|
||||
monitor := func() {
|
||||
// get the lastest address of the master server,
|
||||
// connect to the new address once address changed.
|
||||
curMaster := addr.Address()
|
||||
if curMaster != lastMaster {
|
||||
if curMaster == "" {
|
||||
err := c.conn.Close()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
} else {
|
||||
err := c.conn.Connect(curMaster)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
// connect to addr failed, set
|
||||
// to last known addr in order
|
||||
// to retry next time.
|
||||
curMaster = lastMaster
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
lastMaster = curMaster
|
||||
}
|
||||
|
||||
monitor()
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
for _ = range ticker.C {
|
||||
monitor()
|
||||
}
|
||||
}
|
||||
|
||||
// SetDataset set dataset for the master server to dispatch.
|
||||
//
|
||||
// SetDataset can be call multiple times from different nodes. But
|
||||
// only the first call will be honored.
|
||||
func (c *Client) SetDataset(globPaths []string) error {
|
||||
return c.conn.Call("Service.SetDataset", globPaths, nil)
|
||||
}
|
||||
|
||||
// GetTask gets a new task from the master server.
|
||||
func (c *Client) GetTask() (Task, error) {
|
||||
var t Task
|
||||
err := c.conn.Call("Service.GetTask", 0, &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
// TaskFinished tells the master server a task is finished.
|
||||
func (c *Client) TaskFinished(taskID int) error {
|
||||
return c.conn.Call("Service.TaskFinished", taskID, nil)
|
||||
}
|
@ -0,0 +1,120 @@
|
||||
package master_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
"github.com/PaddlePaddle/recordio"
|
||||
)
|
||||
|
||||
const (
	// totalTask is the number of tasks the test fetches per pass.
	totalTask = 20
	// chunkPerTask is the chunks-per-task value given to the master
	// service; the test writes totalTask*chunkPerTask chunks.
	chunkPerTask = 10
)

// port is the TCP port the test master server listens on; set by init.
var port int
|
||||
|
||||
func init() {
|
||||
log.SetLevel(log.ErrorLevel)
|
||||
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
port = p
|
||||
|
||||
go func(l net.Listener) {
|
||||
s := master.NewService(chunkPerTask, time.Second, 1)
|
||||
server := rpc.NewServer()
|
||||
err := server.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
err = http.Serve(l, mux)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(l)
|
||||
}
|
||||
|
||||
// addresser is a fixed-address implementation of an Address() provider
// for tests: the string value itself is the address.
type addresser string

// Address returns the underlying string unchanged.
func (a addresser) Address() string {
	address := string(a)
	return address
}
|
||||
|
||||
func TestClientFull(t *testing.T) {
|
||||
const p = "/tmp/master_client_test_0"
|
||||
f, err := os.Create(p)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i := 0; i < totalTask*chunkPerTask; i++ {
|
||||
w := recordio.NewWriter(f, -1, -1)
|
||||
w.Write(nil)
|
||||
// call Close to force RecordIO writing a chunk.
|
||||
w.Close()
|
||||
}
|
||||
f.Close()
|
||||
|
||||
c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
|
||||
c.SetDataset([]string{p})
|
||||
|
||||
checkOnePass := func(i int) {
|
||||
var tasks []master.Task
|
||||
for i := 0; i < totalTask; i++ {
|
||||
task, err := c.GetTask()
|
||||
if err != nil {
|
||||
t.Fatal(i, err)
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
_, err = c.GetTask()
|
||||
if err == nil {
|
||||
t.Fatal(i, "should get error.")
|
||||
}
|
||||
|
||||
err = c.TaskFinished(tasks[0].ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tasks = tasks[1:]
|
||||
task, err := c.GetTask()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
|
||||
for _, task := range tasks {
|
||||
err = c.TaskFinished(task.ID)
|
||||
if err != nil {
|
||||
t.Fatal(i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
checkOnePass(i)
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue