Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into row_conv
commit
a18158673f
@ -0,0 +1,50 @@
|
||||
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
|
||||
file(MAKE_DIRECTORY ${GOPATH})
|
||||
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle")
|
||||
file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH})
|
||||
|
||||
function(GO_LIBRARY NAME BUILD_TYPE)
|
||||
if(BUILD_TYPE STREQUAL "STATIC")
|
||||
set(BUILD_MODE -buildmode=c-archive)
|
||||
set(LIB_NAME "lib${NAME}.a")
|
||||
else()
|
||||
set(BUILD_MODE -buildmode=c-shared)
|
||||
if(APPLE)
|
||||
set(LIB_NAME "lib${NAME}.dylib")
|
||||
else()
|
||||
set(LIB_NAME "lib${NAME}.so")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
|
||||
file(RELATIVE_PATH rel ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
# find Paddle directory.
|
||||
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
|
||||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
|
||||
get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY)
|
||||
|
||||
# automatically get all dependencies specified in the source code
|
||||
# for given target.
|
||||
add_custom_target(goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
|
||||
|
||||
# make a symlink that references Paddle inside $GOPATH, so go get
|
||||
# will use the local changes in Paddle rather than checkout Paddle
|
||||
# in github.
|
||||
add_custom_target(copyPaddle
|
||||
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH})
|
||||
add_dependencies(goGet copyPaddle)
|
||||
|
||||
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
|
||||
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
|
||||
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
|
||||
${CMAKE_GO_FLAGS} ${GO_SOURCE}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
|
||||
add_dependencies(${NAME} goGet)
|
||||
|
||||
if(NOT BUILD_TYPE STREQUAL "STATIC")
|
||||
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME} DESTINATION bin)
|
||||
endif()
|
||||
endfunction(GO_LIBRARY)
|
@ -0,0 +1,13 @@
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
|
||||
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
|
||||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
|
||||
|
||||
project(cxx_go C Go)
|
||||
|
||||
include(golang)
|
||||
include(flags)
|
||||
|
||||
go_library(client STATIC)
|
||||
add_subdirectory(test)
|
@ -0,0 +1,232 @@
|
||||
package pserver
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/pserver/internal/connection"
|
||||
)
|
||||
|
||||
// TODO(helin): add RPC call retry logic
|
||||
|
||||
// Selector selects if the client should initialize parameter servers.
|
||||
type Selector interface {
|
||||
Select() bool
|
||||
}
|
||||
|
||||
// Server is the identification of a parameter Server.
|
||||
type Server struct {
|
||||
Index int
|
||||
Addr string
|
||||
}
|
||||
|
||||
// Lister lists currently available parameter servers.
|
||||
type Lister interface {
|
||||
List() []Server
|
||||
}
|
||||
|
||||
// Client is the client to parameter servers.
|
||||
type Client struct {
|
||||
sel Selector
|
||||
pservers []*connection.Conn
|
||||
}
|
||||
|
||||
// NewClient creates a new client.
|
||||
func NewClient(l Lister, pserverNum int, sel Selector) *Client {
|
||||
c := &Client{sel: sel}
|
||||
c.pservers = make([]*connection.Conn, pserverNum)
|
||||
for i := 0; i < pserverNum; i++ {
|
||||
c.pservers[i] = connection.New()
|
||||
}
|
||||
go c.monitorPservers(l, pserverNum)
|
||||
return c
|
||||
}
|
||||
|
||||
// monitorPservers monitors pserver addresses, and updates connection
|
||||
// when the address changes.
|
||||
func (c *Client) monitorPservers(l Lister, pserverNum int) {
|
||||
knownServers := make([]Server, pserverNum)
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
monitor := func() {
|
||||
curServers := make([]Server, pserverNum)
|
||||
list := l.List()
|
||||
for _, l := range list {
|
||||
curServers[l.Index] = l
|
||||
}
|
||||
|
||||
for i := range knownServers {
|
||||
if knownServers[i].Addr != curServers[i].Addr {
|
||||
err := c.pservers[i].Connect(curServers[i].Addr)
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
|
||||
// connect to addr failed, set
|
||||
// to last known addr in order
|
||||
// to retry next time.
|
||||
curServers[i].Addr = knownServers[i].Addr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
knownServers = curServers
|
||||
}
|
||||
|
||||
monitor()
|
||||
for _ = range ticker.C {
|
||||
monitor()
|
||||
}
|
||||
}
|
||||
|
||||
// BeginInitParams begins to initialize parameters on parameter
|
||||
// servers.
|
||||
//
|
||||
// BeginInitParams will be called from multiple trainers, only one
|
||||
// trainer will be selected to initialize the parameters on parameter
|
||||
// servers. Other trainers will be blocked until the initialization is
|
||||
// done, and they need to get the initialized parameters from
|
||||
// parameter servers using GetParams.
|
||||
func (c *Client) BeginInitParams() bool {
|
||||
return c.sel.Select()
|
||||
}
|
||||
|
||||
// InitParam initializes the parameter on parameter servers.
|
||||
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
|
||||
var dummy int
|
||||
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy)
|
||||
}
|
||||
|
||||
// FinishInitParams tells parameter servers client has sent all
|
||||
// parameters to parameter servers as initialization.
|
||||
func (c *Client) FinishInitParams() error {
|
||||
for _, p := range c.pservers {
|
||||
var dummy int
|
||||
err := p.Call("Service.FinishInitParams", dummy, &dummy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendGrads sends gradients to parameter servers for updating
|
||||
// parameters.
|
||||
func (c *Client) SendGrads(grads []Gradient) error {
|
||||
errCh := make(chan error, len(grads))
|
||||
for _, g := range grads {
|
||||
go func(g Gradient) {
|
||||
var dummy int
|
||||
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
|
||||
errCh <- err
|
||||
}(g)
|
||||
}
|
||||
|
||||
recv := 0
|
||||
for err := range errCh {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recv++
|
||||
if recv == len(grads) {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type result struct {
|
||||
idx int
|
||||
param Parameter
|
||||
err error
|
||||
}
|
||||
|
||||
type results []result
|
||||
|
||||
func (r results) Len() int {
|
||||
return len(r)
|
||||
}
|
||||
|
||||
func (r results) Less(i int, j int) bool {
|
||||
return r[i].idx < r[j].idx
|
||||
}
|
||||
|
||||
func (r results) Swap(i int, j int) {
|
||||
r[i], r[j] = r[j], r[i]
|
||||
}
|
||||
|
||||
// GetParams gets parameters from parameter servers.
|
||||
func (c *Client) GetParams(names []string) ([]Parameter, error) {
|
||||
rCh := make(chan result, len(names))
|
||||
|
||||
for idx, name := range names {
|
||||
go func(name string, idx int) {
|
||||
var parameter Parameter
|
||||
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, ¶meter)
|
||||
rCh <- result{idx: idx, param: parameter, err: err}
|
||||
}(name, idx)
|
||||
}
|
||||
|
||||
var rs results
|
||||
recv := 0
|
||||
for r := range rCh {
|
||||
if r.err != nil {
|
||||
return nil, r.err
|
||||
}
|
||||
rs = append(rs, r)
|
||||
|
||||
recv++
|
||||
if recv == len(names) {
|
||||
break
|
||||
}
|
||||
}
|
||||
sort.Sort(rs)
|
||||
|
||||
ps := make([]Parameter, len(rs))
|
||||
for i := range rs {
|
||||
ps[i] = rs[i].param
|
||||
}
|
||||
|
||||
return ps, nil
|
||||
}
|
||||
|
||||
// Save indicates parameters to save the parameter to the given path.
|
||||
func (c *Client) Save(path string) error {
|
||||
errCh := make(chan error, len(c.pservers))
|
||||
|
||||
for _, p := range c.pservers {
|
||||
var dummy int
|
||||
err := p.Call("Service.Save", path, &dummy)
|
||||
errCh <- err
|
||||
}
|
||||
|
||||
recv := 0
|
||||
for err := range errCh {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recv++
|
||||
if recv == len(c.pservers) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(helin): there will be many files under path, need to
|
||||
// merge them into a single file.
|
||||
return nil
|
||||
}
|
||||
|
||||
func strHash(s string) uint32 {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(s))
|
||||
return h.Sum32()
|
||||
}
|
||||
|
||||
// TODO(helin): now partition only select which parameter server to
|
||||
// send the entire parameter. We need to partition a parameter into
|
||||
// small blocks and send to different parameter servers.
|
||||
func (c *Client) partition(key string) int {
|
||||
return int(strHash(key) % uint32(len(c.pservers)))
|
||||
}
|
@ -0,0 +1,123 @@
|
||||
package pserver_test
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/pserver"
|
||||
)
|
||||
|
||||
const numPserver = 10
|
||||
|
||||
var port [numPserver]int
|
||||
|
||||
func init() {
|
||||
for i := 0; i < numPserver; i++ {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
port[i] = p
|
||||
|
||||
go func(l net.Listener) {
|
||||
s := pserver.NewService()
|
||||
server := rpc.NewServer()
|
||||
err := server.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
err = http.Serve(l, mux)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(l)
|
||||
}
|
||||
}
|
||||
|
||||
type selector bool
|
||||
|
||||
func (s selector) Select() bool {
|
||||
return bool(s)
|
||||
}
|
||||
|
||||
type lister []pserver.Server
|
||||
|
||||
func (l lister) List() []pserver.Server {
|
||||
return l
|
||||
}
|
||||
|
||||
func TestClientFull(t *testing.T) {
|
||||
servers := make([]pserver.Server, numPserver)
|
||||
for i := 0; i < numPserver; i++ {
|
||||
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
|
||||
}
|
||||
c := pserver.NewClient(lister(servers), len(servers), selector(true))
|
||||
selected := c.BeginInitParams()
|
||||
if !selected {
|
||||
t.Fatal("should be selected.")
|
||||
}
|
||||
|
||||
const numParameter = 100
|
||||
for i := 0; i < numParameter; i++ {
|
||||
var p pserver.Parameter
|
||||
p.Name = "p_" + strconv.Itoa(i)
|
||||
p.ElementType = pserver.Float32
|
||||
p.Content = make([]byte, (i+1)*100)
|
||||
err := c.InitParam(pserver.ParameterWithConfig{Param: p})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := c.FinishInitParams()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var grads []pserver.Gradient
|
||||
for i := 0; i < numParameter/2; i++ {
|
||||
var g pserver.Gradient
|
||||
g.Name = "p_" + strconv.Itoa(i)
|
||||
g.ElementType = pserver.Float32
|
||||
g.Content = make([]byte, (i+1)*100)
|
||||
grads = append(grads, g)
|
||||
}
|
||||
|
||||
err = c.SendGrads(grads)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
names := make([]string, numParameter)
|
||||
for i := 0; i < numParameter; i++ {
|
||||
names[i] = "p_" + strconv.Itoa(i)
|
||||
}
|
||||
|
||||
params, err := c.GetParams(names)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(names) != len(params) {
|
||||
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
|
||||
}
|
||||
|
||||
for i := range params {
|
||||
if names[i] != params[i].Name {
|
||||
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i])
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,84 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// TODO(helin): add TCP re-connect logic
|
||||
|
||||
// Conn is a connection to a parameter server
|
||||
type Conn struct {
|
||||
mu sync.Mutex
|
||||
client *rpc.Client
|
||||
waitConn chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new connection.
|
||||
func New() *Conn {
|
||||
c := &Conn{}
|
||||
return c
|
||||
}
|
||||
|
||||
// Connect connects the connection to a address.
|
||||
func (c *Conn) Connect(addr string) error {
|
||||
c.mu.Lock()
|
||||
if c.client != nil {
|
||||
err := c.client.Close()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
c.client = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
client, err := rpc.DialHTTP("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.client == nil {
|
||||
c.client = client
|
||||
if c.waitConn != nil {
|
||||
close(c.waitConn)
|
||||
c.waitConn = nil
|
||||
}
|
||||
} else {
|
||||
return errors.New("client already set from a concurrent goroutine")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Call make a RPC call.
|
||||
//
|
||||
// Call will be blocked until the connection to remote RPC service
|
||||
// being established.
|
||||
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
|
||||
c.mu.Lock()
|
||||
client := c.client
|
||||
var waitCh chan struct{}
|
||||
if client == nil {
|
||||
if c.waitConn != nil {
|
||||
waitCh = c.waitConn
|
||||
} else {
|
||||
waitCh = make(chan struct{})
|
||||
c.waitConn = waitCh
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if waitCh != nil {
|
||||
// wait until new connection being established
|
||||
<-waitCh
|
||||
return c.Call(serviceMethod, args, reply)
|
||||
}
|
||||
|
||||
return client.Call(serviceMethod, args, reply)
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue