You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/go/pserver/service.go

225 lines
5.1 KiB

package pserver
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"errors"
"fmt"
"os"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// ElementType identifies the data type of the elements stored in a
// Parameter's Content. The supported values are the Int32 through Float64
// constants declared in this file.
type ElementType int
// Error messages returned by the Service RPC methods.
const (
// AlreadyInitialized is the message of the error returned by InitParam and
// FinishInitParams once initialization has been marked as finished.
AlreadyInitialized = "pserver already initialized"
// Uninitialized is the message of the error returned by SendGrad when it is
// called before initialization has been marked as finished.
Uninitialized = "pserver not fully initialized"
)
const (
// checkpoint_path is the path prefix for checkpoint files; Save appends the
// pserver index to it to form the checkpoint file name.
checkpoint_path = "/checkpoints/"
)
// Supported element types. The values are assigned consecutively by iota,
// starting with Int32 = 0, so the declaration order must not change.
const (
Int32 ElementType = iota
UInt32
Int64
UInt64
Float32
Float64
)
// PsDesired is the etcd path under which the desired pserver count is stored.
const PsDesired = "/ps_desired"
// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
// Name identifies the parameter; the Service uses it as the lookup key.
Name string
// ElementType is the type of the elements held in Content.
ElementType ElementType
// Content is the parameter data as raw bytes.
Content []byte
}
// ParameterWithConfig bundles a parameter with its configuration and its
// training state.
type ParameterWithConfig struct {
// Param is the parameter itself.
Param Parameter
Config []byte // parameter configuration in Protocol Buffer format
State []byte // parameter training state
}
// Gradient is the gradient of the parameter. It has the same fields as
// Parameter; SendGrad uses its Name to find the parameter to update.
type Gradient Parameter
// Service is the RPC service for pserver.
type Service struct {
// initialized is closed by FinishInitParams; the other methods test or
// wait on it to learn whether parameter initialization has finished.
initialized chan struct{}
// idx is the index passed to NewService; Save appends it to the
// checkpoint path.
idx int
// mu is held by the RPC methods while they access optMap.
mu sync.Mutex
// optMap maps a parameter name to the optimizer created for it by InitParam.
optMap map[string]*optimizer
}
// Checkpoint holds the metadata that Save records about a checkpoint.
//
// NOTE(review): all fields are unexported, and encoding/gob only transmits
// exported fields, so a Checkpoint cannot be gob-encoded as declared.
type Checkpoint struct {
// uuid is set by Save to the checkpoint file path.
uuid string
// md5sum is a hex-encoded MD5 value computed by Save.
md5sum string
// timestamp is the string form of the time at which Save ran.
timestamp string
}
// GetBytes gob-encodes its arguments and returns the resulting bytes.
//
// The arguments are encoded together as one []interface{} value, so each
// argument is transmitted as an interface value: per the encoding/gob rules,
// concrete types other than the pre-registered basic ones must have been
// registered with gob.Register, otherwise an error is returned. On error the
// returned slice is nil.
func GetBytes(content ...interface{}) ([]byte, error) {
	out := new(bytes.Buffer)
	if err := gob.NewEncoder(out).Encode(content); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// NewService creates a new, not yet initialized Service for the pserver with
// the given index. The returned service has an empty optimizer map and an
// open initialized channel. The returned error is always nil.
func NewService(idx int) (*Service, error) {
	return &Service{
		initialized: make(chan struct{}),
		idx:         idx,
		optMap:      make(map[string]*optimizer),
	}, nil
}
// InitParam initializes a parameter by creating an optimizer for it and
// storing that optimizer under the parameter's name, replacing any optimizer
// previously stored under the same name.
//
// It returns an error with message AlreadyInitialized if initialization has
// already been marked as finished. dummy is unused.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	// Non-blocking probe: a closed channel means initialization is over.
	select {
	case <-s.initialized:
		return errors.New(AlreadyInitialized)
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
	key := paramWithConfigs.Param.Name
	s.optMap[key] = newOptimizer(paramWithConfigs)
	return nil
}
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
//
// The first call closes the initialized channel. Every later call returns an
// error with message AlreadyInitialized. Both arguments are unused.
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	// The check and the close must happen atomically: without the lock two
	// concurrent callers could both see the channel open, and the second
	// close would panic. No method holds s.mu while waiting on
	// s.initialized, so taking the lock here cannot deadlock.
	s.mu.Lock()
	defer s.mu.Unlock()

	select {
	case <-s.initialized:
		return errors.New(AlreadyInitialized)
	default:
	}

	close(s.initialized)
	return nil
}
// SendGrad sends a gradient to the parameter server, which applies it to the
// parameter named g.Name via that parameter's optimizer.
//
// It returns an error with message Uninitialized if initialization has not
// been marked as finished, an error if no parameter named g.Name exists, and
// otherwise whatever the optimizer's UpdateParameter returns. dummy is unused.
func (s *Service) SendGrad(g Gradient, dummy *int) error {
	// Non-blocking probe: only proceed once the channel has been closed.
	select {
	case <-s.initialized:
	default:
		return errors.New(Uninitialized)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if target, found := s.optMap[g.Name]; found {
		return target.UpdateParameter(g)
	}
	return fmt.Errorf("parameter: %s does not exist", g.Name)
}
// GetParam gets a parameter from the parameter server. It blocks until
// initialization has finished, then fills *parameter with the name, element
// type and current weights of the parameter called name. It returns an error
// if no such parameter exists.
func (s *Service) GetParam(name string, parameter *Parameter) error {
	<-s.initialized

	s.mu.Lock()
	defer s.mu.Unlock()

	opt, found := s.optMap[name]
	if !found {
		return fmt.Errorf("parameter: %s does not exist", name)
	}

	// The weights are handed out without making a copy, so the byte slice
	// may be written by another goroutine while the RPC layer serializes
	// it. This race is tolerated on purpose: mini-batch based deep learning
	// optimization is stochastic by nature, and skipping the copy saves
	// duplicating the parameter content.
	*parameter = Parameter{
		Name:        name,
		ElementType: opt.elementType,
		Content:     opt.GetWeights(),
	}
	return nil
}
// Save tells the parameter server to save parameters.
//
// It blocks until initialization has finished, then writes a single
// checkpoint file containing every parameter (name, element type, weights and
// optimizer state) as one gob-encoded []ParameterWithConfig. The file is
// created at checkpoint_path followed by the pserver index, truncating any
// previous checkpoint at that location. The checkpoint's path, MD5 sum and
// timestamp are logged.
//
// path is currently ignored (see FIXME below) and dummy is unused. An error
// is returned if encoding, creating, writing, flushing or closing the
// checkpoint file fails.
func (s *Service) Save(path string, dummy *int) error {
	//FIXME: checkpoint is only used by pserver
	// and has a constant path of */checkpoints/{pserver_idx}*
	<-s.initialized
	s.mu.Lock()
	defer s.mu.Unlock()

	// Gather all parameters first so that they end up in one file; writing
	// them one by one to the same path would leave only the last one.
	params := make([]ParameterWithConfig, 0, len(s.optMap))
	for name, opt := range s.optMap {
		var p ParameterWithConfig
		p.Param.Name = name
		p.Param.ElementType = opt.elementType
		p.Param.Content = opt.GetWeights()
		p.State = opt.GetStates()
		params = append(params, p)
	}

	// Encode the concrete slice directly: gob refuses interface-wrapped
	// values whose concrete type has not been registered.
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(params); err != nil {
		return fmt.Errorf("encoding checkpoint: %v", err)
	}
	content := buf.Bytes()

	// md5.Sum hashes content; hash.Hash.Sum(content) would merely append
	// the digest of the (empty) data written so far to content.
	sum := md5.Sum(content)
	ck := Checkpoint{
		uuid:      checkpoint_path + strconv.Itoa(s.idx),
		md5sum:    hex.EncodeToString(sum[:]),
		timestamp: time.Now().String(),
	}
	// TODO: according design doc, need to save uuid to etcd in json format
	// {\"uuid\": [UUID], \"md5\", \"MD5 sum\", \"timestamp\": xxxx}
	log.Infof("parameter checkpoint uuid=%s md5=%s timestamp=%s", ck.uuid, ck.md5sum, ck.timestamp)

	// os.Create truncates an existing file, so no explicit removal is needed.
	f, err := os.Create(ck.uuid)
	if err != nil {
		return fmt.Errorf("creating checkpoint file %s: %v", ck.uuid, err)
	}

	w := bufio.NewWriter(f)
	_, err = w.Write(content)
	if err == nil {
		// Without Flush the buffered tail never reaches the file.
		err = w.Flush()
	}
	// Always close the file, but report the first error encountered.
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err != nil {
		return fmt.Errorf("writing checkpoint file %s: %v", ck.uuid, err)
	}
	return nil
}