Merge pull request #2188 from helinwang/pserver
Implement Pserver RPC, gradient update logic, cgo partrefactor_docs
commit
519555e399
@ -0,0 +1 @@
|
|||||||
|
pserver
|
@ -0,0 +1,33 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/rpc"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
port := flag.Int("p", 0, "port of the pserver")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
s := pserver.NewService()
|
||||||
|
err := rpc.Register(s)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rpc.HandleHTTP()
|
||||||
|
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = http.Serve(l, nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,52 @@
|
|||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
|
#include "optimizer.h"
|
||||||
|
|
||||||
|
typedef int (*update_func)(void*, void*, paddle_element_type, const void*, int);
|
||||||
|
typedef void (*release_func)(void*);
|
||||||
|
|
||||||
|
typedef struct paddle_optimizer {
|
||||||
|
update_func update;
|
||||||
|
release_func release;
|
||||||
|
void* optimizer;
|
||||||
|
} paddle_optimizer;
|
||||||
|
|
||||||
|
void paddle_release_optimizer(paddle_optimizer* o) {
|
||||||
|
o->release(o->optimizer);
|
||||||
|
free(o);
|
||||||
|
}
|
||||||
|
|
||||||
|
int paddle_update_parameter(paddle_optimizer* o,
|
||||||
|
void* buffer,
|
||||||
|
paddle_element_type element_type,
|
||||||
|
const void* gradient,
|
||||||
|
int num_bytes) {
|
||||||
|
return o->update(o->optimizer, buffer, element_type, gradient, num_bytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef struct { double learning_rate; } SGD_optimizer;
|
||||||
|
|
||||||
|
int update_SGD(void* optimizer,
|
||||||
|
void* buffer,
|
||||||
|
paddle_element_type element_type,
|
||||||
|
const void* gradient,
|
||||||
|
int num_bytes) {
|
||||||
|
SGD_optimizer* o = (SGD_optimizer*)optimizer;
|
||||||
|
// TODO
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void release_SGD(void* optimizer) {
|
||||||
|
SGD_optimizer* o = (SGD_optimizer*)optimizer;
|
||||||
|
// nothing allocated on heap
|
||||||
|
}
|
||||||
|
|
||||||
|
paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate) {
|
||||||
|
SGD_optimizer* impl = (SGD_optimizer*)malloc(sizeof(SGD_optimizer));
|
||||||
|
impl->learning_rate = learning_rate;
|
||||||
|
paddle_optimizer* opt = (paddle_optimizer*)malloc(sizeof(paddle_optimizer));
|
||||||
|
opt->update = update_SGD;
|
||||||
|
opt->release = release_SGD;
|
||||||
|
opt->optimizer = impl;
|
||||||
|
return opt;
|
||||||
|
}
|
@ -0,0 +1,51 @@
|
|||||||
|
package pserver
|
||||||
|
|
||||||
|
/*
|
||||||
|
#include "optimizer.h"
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type optimizerType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
sgd optimizerType = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
var nullPtr = unsafe.Pointer(uintptr(0))
|
||||||
|
|
||||||
|
type optimizer struct {
|
||||||
|
opt *C.struct_paddle_optimizer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOptimizer(t optimizerType, learning_rate float64) *optimizer {
|
||||||
|
o := &optimizer{}
|
||||||
|
o.opt = C.paddle_create_SGD_optimizer(C.double(learning_rate))
|
||||||
|
return o
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error {
|
||||||
|
if len(p.Content) != len(g.Content) {
|
||||||
|
return fmt.Errorf("parameter and gradient length not match, parameter: %d, gradient: %d", len(p.Content), len(g.Content))
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.ElementType != g.ElementType {
|
||||||
|
return fmt.Errorf("parameter and gradient element type not match, parameter: %v, gradient: %v", p.ElementType, g.ElementType)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
|
||||||
|
if r != 0 {
|
||||||
|
return fmt.Errorf("optimizer update returned error code: %d", r)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *optimizer) Cleanup() {
|
||||||
|
if unsafe.Pointer(o.opt) != nullPtr {
|
||||||
|
C.paddle_release_optimizer(o.opt)
|
||||||
|
o.opt = (*C.struct_paddle_optimizer)(nullPtr)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,22 @@
|
|||||||
|
#ifndef PADDLE_PSERVER_OPTIMIZER_H
|
||||||
|
#define PADDLE_PSERVER_OPTIMIZER_H
|
||||||
|
|
||||||
|
typedef enum {
|
||||||
|
PADDLE_ELEMENT_TYPE_INT32 = 0,
|
||||||
|
PADDLE_ELEMENT_TYPE_UINT32 = 1,
|
||||||
|
PADDLE_ELEMENT_TYPE_INT64 = 2,
|
||||||
|
PADDLE_ELEMENT_TYPE_UINT64 = 3,
|
||||||
|
PADDLE_ELEMENT_TYPE_FLOAT32 = 4,
|
||||||
|
PADDLE_ELEMENT_TYPE_FLOAT64 = 5,
|
||||||
|
} paddle_element_type;
|
||||||
|
|
||||||
|
struct paddle_optimizer;
|
||||||
|
struct paddle_optimizer* paddle_create_SGD_optimizer(double learning_rate);
|
||||||
|
void paddle_release_optimizer(struct paddle_optimizer* o);
|
||||||
|
int paddle_update_parameter(struct paddle_optimizer* o,
|
||||||
|
void* buffer,
|
||||||
|
paddle_element_type element_type,
|
||||||
|
const void* gradient,
|
||||||
|
int num_bytes);
|
||||||
|
|
||||||
|
#endif /* PADDLE_PSERVER_OPTIMIZER_H */
|
@ -0,0 +1,8 @@
|
|||||||
|
package pserver
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSGDCreateRelease(t *testing.T) {
|
||||||
|
o := newOptimizer(sgd, 1)
|
||||||
|
o.Cleanup()
|
||||||
|
}
|
@ -0,0 +1,190 @@
|
|||||||
|
package pserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ElementType is the type of elements of a Parameter.
|
||||||
|
type ElementType int
|
||||||
|
|
||||||
|
var ErrAlreadyInitialized = errors.New("pserver already initialized")
|
||||||
|
var ErrUninitialized = errors.New("pserver not fully initialized")
|
||||||
|
|
||||||
|
// Supported element types
|
||||||
|
const (
|
||||||
|
Int32 ElementType = iota
|
||||||
|
UInt32
|
||||||
|
Int64
|
||||||
|
UInt64
|
||||||
|
Float32
|
||||||
|
Float64
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parameter is a piece of data to sync with the parameter server.
|
||||||
|
type Parameter struct {
|
||||||
|
Name string
|
||||||
|
ElementType ElementType
|
||||||
|
Content []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParameterWithConfig contains the parameter and the configuration.
|
||||||
|
type ParameterWithConfig struct {
|
||||||
|
Param Parameter
|
||||||
|
Config []byte // parameter configuration in Proto Buffer format
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gradient is the gradient of the parameter.
|
||||||
|
type Gradient Parameter
|
||||||
|
|
||||||
|
// Service is the RPC service for pserver.
|
||||||
|
type Service struct {
|
||||||
|
initialized chan struct{}
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
opt *optimizer
|
||||||
|
paramMap map[string]Parameter
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewService creates a new service.
|
||||||
|
func NewService() *Service {
|
||||||
|
s := &Service{}
|
||||||
|
s.paramMap = make(map[string]Parameter)
|
||||||
|
s.initialized = make(chan struct{})
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginInitParams tells the parameter server that the parameter
|
||||||
|
// initialization has begun.
|
||||||
|
func (s *Service) BeginInitParams(config []byte, dummy *int) error {
|
||||||
|
select {
|
||||||
|
case <-s.initialized:
|
||||||
|
return ErrAlreadyInitialized
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.opt != nil {
|
||||||
|
s.opt.Cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(helin): parse learning rate from config
|
||||||
|
s.opt = newOptimizer(sgd, 0.01)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitParam initializes a parameter.
|
||||||
|
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
|
||||||
|
select {
|
||||||
|
case <-s.initialized:
|
||||||
|
return ErrAlreadyInitialized
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(helin): parse parameter config
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// TODO(helin): check if paramWithConfigs.Param.Content is
|
||||||
|
// properly memory aligned, if not, make copy to a memory
|
||||||
|
// aligned region.
|
||||||
|
s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FinishInitParams tells the parameter server that the parameter
|
||||||
|
// initialization has finished.
|
||||||
|
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
|
||||||
|
select {
|
||||||
|
case <-s.initialized:
|
||||||
|
return ErrAlreadyInitialized
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
close(s.initialized)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendGrads sends gradients to parameter servers for parameter
|
||||||
|
// optimization.
|
||||||
|
func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
|
||||||
|
select {
|
||||||
|
case <-s.initialized:
|
||||||
|
default:
|
||||||
|
return ErrUninitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
count := len(grads)
|
||||||
|
if count == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, g := range grads {
|
||||||
|
if _, ok := s.paramMap[g.Name]; !ok {
|
||||||
|
return fmt.Errorf("parameter: %s does not exist", g.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
errCh := make(chan error, count)
|
||||||
|
for _, g := range grads {
|
||||||
|
go func(p Parameter, g Gradient) {
|
||||||
|
err := s.opt.UpdateParameter(p, g)
|
||||||
|
errCh <- err
|
||||||
|
}(s.paramMap[g.Name], g)
|
||||||
|
}
|
||||||
|
|
||||||
|
recv := 0
|
||||||
|
for err := range errCh {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
recv++
|
||||||
|
if recv == count {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParams gets parameters from the parameter server.
|
||||||
|
func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
|
||||||
|
<-s.initialized
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
for _, n := range names {
|
||||||
|
if _, ok := s.paramMap[n]; !ok {
|
||||||
|
return fmt.Errorf("parameter: %s does not exist", n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*parameters = make([]Parameter, len(names))
|
||||||
|
for i, n := range names {
|
||||||
|
// The parameter content (a byte slice) may change
|
||||||
|
// during RPC serialization due to write from other
|
||||||
|
// goroutine, we allow it since mini-batch based deep
|
||||||
|
// learning optimization methods are stochastic in
|
||||||
|
// nature. This race condition is allowed deliberately
|
||||||
|
// to save the program from making a copy of the
|
||||||
|
// paramter content.
|
||||||
|
(*parameters)[i] = s.paramMap[n]
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save tells the parameter server to save parameters.
|
||||||
|
func (s *Service) Save(path string, dummy *int) error {
|
||||||
|
<-s.initialized
|
||||||
|
|
||||||
|
// TODO
|
||||||
|
return nil
|
||||||
|
}
|
@ -0,0 +1,165 @@
|
|||||||
|
package pserver_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFull(t *testing.T) {
|
||||||
|
s := pserver.NewService()
|
||||||
|
var dummy int
|
||||||
|
err := s.BeginInitParams(nil, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
var p pserver.Parameter
|
||||||
|
p.Name = "param_a"
|
||||||
|
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
|
||||||
|
p.ElementType = pserver.Int32
|
||||||
|
err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
var p1 pserver.Parameter
|
||||||
|
p1.Name = "param_b"
|
||||||
|
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||||
|
p1.ElementType = pserver.Float32
|
||||||
|
err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.FinishInitParams(0, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
var params []pserver.Parameter
|
||||||
|
err = s.GetParams([]string{"param_b", "param_a"}, ¶ms)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)}
|
||||||
|
err = s.SendGrads(grads, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
var params1 []pserver.Parameter
|
||||||
|
err = s.GetParams([]string{"param_b", "param_a"}, ¶ms1)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(params) != 2 {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't compare content, since it's already changed by
|
||||||
|
// gradient update.
|
||||||
|
params1[0].Content = nil
|
||||||
|
params1[0].Content = nil
|
||||||
|
p.Content = nil
|
||||||
|
p1.Content = nil
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMultipleInit(t *testing.T) {
|
||||||
|
s := pserver.NewService()
|
||||||
|
var dummy int
|
||||||
|
err := s.BeginInitParams(nil, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// this is fine, it's possible for client to call init
|
||||||
|
// multiple times.
|
||||||
|
err = s.BeginInitParams(nil, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.FinishInitParams(0, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.FinishInitParams(0, &dummy)
|
||||||
|
if err != pserver.ErrAlreadyInitialized {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.BeginInitParams(nil, &dummy)
|
||||||
|
if err != pserver.ErrAlreadyInitialized {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUninitialized(t *testing.T) {
|
||||||
|
s := pserver.NewService()
|
||||||
|
var dummy int
|
||||||
|
err := s.SendGrads(nil, &dummy)
|
||||||
|
if err != pserver.ErrUninitialized {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlockUntilInitialized(t *testing.T) {
|
||||||
|
s := pserver.NewService()
|
||||||
|
ch := make(chan struct{}, 2)
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
var params []pserver.Parameter
|
||||||
|
err := s.GetParams(nil, ¶ms)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
ch <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
var dummy int
|
||||||
|
err := s.Save("", &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
wg.Done()
|
||||||
|
ch <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var dummy int
|
||||||
|
err := s.BeginInitParams(nil, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ch:
|
||||||
|
// some function returned before initialization is completed.
|
||||||
|
t.FailNow()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.FinishInitParams(0, &dummy)
|
||||||
|
if err != nil {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
Loading…
Reference in new issue