parent
7e93921a79
commit
0babf84b0f
@ -0,0 +1,123 @@
|
||||
package pserver_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
|
||||
)
|
||||
|
||||
const numPserver = 10
|
||||
|
||||
var port [numPserver]int
|
||||
|
||||
func init() {
|
||||
for i := 0; i < numPserver; i++ {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
port[i] = p
|
||||
|
||||
go func(l net.Listener) {
|
||||
s := pserver.NewService()
|
||||
server := rpc.NewServer()
|
||||
err := server.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
err = http.Serve(l, mux)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(l)
|
||||
}
|
||||
}
|
||||
|
||||
type selector bool
|
||||
|
||||
func (s selector) Select() bool {
|
||||
return bool(s)
|
||||
}
|
||||
|
||||
type lister []pserver.Server
|
||||
|
||||
func (l lister) List() []pserver.Server {
|
||||
return l
|
||||
}
|
||||
|
||||
func TestClientFull(t *testing.T) {
|
||||
servers := make([]pserver.Server, numPserver)
|
||||
for i := 0; i < numPserver; i++ {
|
||||
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
|
||||
}
|
||||
c := pserver.NewClient(lister(servers), len(servers), selector(true))
|
||||
selected := c.BeginInitParams()
|
||||
if !selected {
|
||||
t.Fatal("should be selected.")
|
||||
}
|
||||
|
||||
const numParameter = 100
|
||||
for i := 0; i < numParameter; i++ {
|
||||
var p pserver.Parameter
|
||||
p.Name = "p_" + strconv.Itoa(i)
|
||||
p.ElementType = pserver.Float32
|
||||
p.Content = make([]byte, (i+1)*100)
|
||||
err := c.InitParam(pserver.ParameterWithConfig{Param: p})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := c.FinishInitParams()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var grads []pserver.Gradient
|
||||
for i := 0; i < numParameter/2; i++ {
|
||||
var g pserver.Gradient
|
||||
g.Name = "p_" + strconv.Itoa(i)
|
||||
g.ElementType = pserver.Float32
|
||||
g.Content = make([]byte, (i+1)*100)
|
||||
grads = append(grads, g)
|
||||
}
|
||||
|
||||
err = c.SendGrads(grads)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
names := make([]string, numParameter)
|
||||
for i := 0; i < numParameter; i++ {
|
||||
names[i] = "p_" + strconv.Itoa(i)
|
||||
}
|
||||
|
||||
params, err := c.GetParams(names)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(names) != len(params) {
|
||||
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
|
||||
}
|
||||
|
||||
for i := range params {
|
||||
if names[i] != params[i].Name {
|
||||
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i])
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,84 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TODO(helin): add TCP re-connect logic
|
||||
|
||||
// Conn is a connection to a parameter server
|
||||
type Conn struct {
|
||||
mu sync.Mutex
|
||||
client *rpc.Client
|
||||
waitConn chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new connection.
|
||||
func New() *Conn {
|
||||
c := &Conn{}
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect connects the connection to a address.
|
||||
func (c *Conn) Connect(addr string) error {
|
||||
c.mu.Lock()
|
||||
if c.client != nil {
|
||||
err := c.client.Close()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
c.client = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
client, err := rpc.DialHTTP("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.client == nil {
|
||||
c.client = client
|
||||
if c.waitConn != nil {
|
||||
close(c.waitConn)
|
||||
c.waitConn = nil
|
||||
}
|
||||
} else {
|
||||
return errors.New("client already set from a concurrent goroutine")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call make a RPC call.
|
||||
//
|
||||
// Call will be blocked until the connection to remote RPC service
|
||||
// being established.
|
||||
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
|
||||
c.mu.Lock()
|
||||
client := c.client
|
||||
var waitCh chan struct{}
|
||||
if client == nil {
|
||||
if c.waitConn != nil {
|
||||
waitCh = c.waitConn
|
||||
} else {
|
||||
waitCh = make(chan struct{})
|
||||
c.waitConn = waitCh
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if waitCh != nil {
|
||||
// wait until new connection being established
|
||||
<-waitCh
|
||||
return c.Call(serviceMethod, args, reply)
|
||||
}
|
||||
|
||||
return client.Call(serviceMethod, args, reply)
|
||||
}
|
@ -0,0 +1,10 @@
|
||||
package pserver
|
||||
|
||||
type partitioner struct {
|
||||
shardNum int
|
||||
}
|
||||
|
||||
// partitioner partitions the parameters into shards.
|
||||
func newPartitioner(shardNum int) *partitioner {
|
||||
return &partitioner{shardNum: shardNum}
|
||||
}
|
Loading…
Reference in new issue