parent
f05649afb7
commit
72a73ab6d2
@ -0,0 +1,74 @@
|
||||
package master
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/connection"
|
||||
)
|
||||
|
||||
// Addresser provide the address of the master server.
|
||||
type Addresser interface {
|
||||
Address() string
|
||||
}
|
||||
|
||||
// Client is the client of the master server.
|
||||
type Client struct {
|
||||
conn *connection.Conn
|
||||
}
|
||||
|
||||
// NewClient creates a new Client.
|
||||
func NewClient(addr Addresser) *Client {
|
||||
c := &Client{}
|
||||
c.conn = connection.New()
|
||||
go c.monitorMaster(addr)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Client) monitorMaster(addr Addresser) {
|
||||
lastMaster := ""
|
||||
monitor := func() {
|
||||
curMaster := addr.Address()
|
||||
if curMaster != lastMaster {
|
||||
if curMaster == "" {
|
||||
err := c.conn.Close()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
} else {
|
||||
err := c.conn.Connect(curMaster)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
// connect to addr failed, set
|
||||
// to last known addr in order
|
||||
// to retry next time.
|
||||
curMaster = lastMaster
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
lastMaster = curMaster
|
||||
}
|
||||
|
||||
monitor()
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
for _ = range ticker.C {
|
||||
monitor()
|
||||
}
|
||||
}
|
||||
|
||||
// GetTask gets a new task from the master server.
|
||||
func (c *Client) GetTask() (Task, error) {
|
||||
var dummy int
|
||||
var t Task
|
||||
err := c.conn.Call("Service.GetTask", dummy, &t)
|
||||
return t, err
|
||||
}
|
||||
|
||||
// TaskFinished tells the master server a task is finished.
|
||||
func (c *Client) TaskFinished(taskID int) error {
|
||||
var dummy int
|
||||
return c.conn.Call("Service.TaskFinished", taskID, &dummy)
|
||||
}
|
@ -0,0 +1,78 @@
|
||||
package master_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
)
|
||||
|
||||
const (
|
||||
totalTask = 20
|
||||
chunkPerTask = 10
|
||||
)
|
||||
|
||||
var port int
|
||||
|
||||
func init() {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
port = p
|
||||
|
||||
go func(l net.Listener) {
|
||||
chunks := make([]master.Chunk, totalTask)
|
||||
s := master.NewService(chunks, chunkPerTask, time.Second, 1)
|
||||
server := rpc.NewServer()
|
||||
err := server.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
err = http.Serve(l, mux)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(l)
|
||||
}
|
||||
|
||||
type addresser string
|
||||
|
||||
func (a addresser) Address() string {
|
||||
return string(a)
|
||||
}
|
||||
|
||||
func TestClientFull(t *testing.T) {
|
||||
c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
|
||||
|
||||
for i := 0; i < 5*totalTask/chunkPerTask; i++ {
|
||||
task, err := c.GetTask()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if len(task.Chunks) != chunkPerTask {
|
||||
t.Fatal("wrong number of chunk per task", len(task.Chunks))
|
||||
}
|
||||
|
||||
err = c.TaskFinished(task.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in new issue