|
|
|
@ -5,10 +5,11 @@ import (
|
|
|
|
|
"bytes"
|
|
|
|
|
"crypto/md5"
|
|
|
|
|
"encoding/gob"
|
|
|
|
|
"encoding/hex"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"os"
|
|
|
|
|
"path/filepath"
|
|
|
|
|
"strconv"
|
|
|
|
|
"sync"
|
|
|
|
|
"time"
|
|
|
|
@ -26,10 +27,6 @@ const (
|
|
|
|
|
Uninitialized = "pserver not fully initialized"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
// checkpoint_path is the path prefix to which the pserver index is
// appended to form the checkpoint file name.
checkpoint_path = "./checkpoints/"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// Supported element types
|
|
|
|
|
const (
|
|
|
|
|
Int32 ElementType = iota
|
|
|
|
@ -51,49 +48,68 @@ type Parameter struct {
|
|
|
|
|
type ParameterWithConfig struct {
|
|
|
|
|
Param Parameter // the parameter itself (name, element type and content)
|
|
|
|
|
Config []byte // parameter configuration in Protocol Buffer format
|
|
|
|
|
State []byte // parameter training state, as returned by the optimizer's GetStates
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// parameterCheckPoint is the checkpoint of a single parameter: the
// parameter (with its configuration) together with its optimizer state.
|
|
|
|
|
type parameterCheckPoint struct {
|
|
|
|
|
ParamConfig ParameterWithConfig // parameter name, element type, content and config
|
|
|
|
|
State []byte // optimizer state; handed to newOptimizer when a checkpoint is restored
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// checkpointMeta is the checkpoint signature. It is marshalled to JSON and
// stored in etcd so that a checkpoint file can be located and identified.
|
|
|
|
|
type checkpointMeta struct {
|
|
|
|
|
UUID string `json:"uuid"` // checkpoint file path: checkpoint path prefix + pserver index
|
|
|
|
|
Md5sum string `json:"md5sum"` // digest recorded for the gob-encoded checkpoint
|
|
|
|
|
Timestamp string `json:"timestamp"` // time.Now().String() at the moment the checkpoint was taken
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Checkpoint is the state of a pserver shard as persisted to a file: one
// parameterCheckPoint per parameter. It is gob-encoded when written and can
// be passed to NewService to restore the optimizers.
|
|
|
|
|
type Checkpoint []parameterCheckPoint
|
|
|
|
|
|
|
|
|
|
// Gradient is the gradient of the parameter. It has the same underlying
// structure as Parameter.
|
|
|
|
|
type Gradient Parameter
|
|
|
|
|
|
|
|
|
|
// Service is the RPC service for pserver.
|
|
|
|
|
type Service struct {
|
|
|
|
|
initialized chan struct{}
|
|
|
|
|
idx int
|
|
|
|
|
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
optMap map[string]*optimizer
|
|
|
|
|
initialized chan struct{}
|
|
|
|
|
idx int
|
|
|
|
|
checkpointInterval int
|
|
|
|
|
checkpointPath string
|
|
|
|
|
client *EtcdClient
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
optMap map[string]*optimizer
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// checkpoint records the signature of a saved parameter file. The type is
// registered with gob in NewService.
type checkpoint struct {
|
|
|
|
|
Uuid string // path of the saved file: checkpoint_path + pserver index
|
|
|
|
|
Md5sum string // hex-encoded digest recorded for the saved content
|
|
|
|
|
Timestamp string // time.Now().String() at save time
|
|
|
|
|
}
|
|
|
|
|
// //serialize ParameterWithConfig to byte stream
|
|
|
|
|
// func GetBytes(content ...interface{}) ([]byte, error) {
|
|
|
|
|
|
|
|
|
|
// GetBytes gob-encodes content and returns the resulting byte stream.
//
// The variadic arguments are encoded together as a single []interface{}
// value, not one after another, so a decoder must decode into a
// []interface{}. Every concrete type passed in must therefore be known to
// gob: basic types are pre-registered, all others need gob.Register.
//
// It returns a nil slice and a wrapped error if encoding fails.
func GetBytes(content ...interface{}) ([]byte, error) {
	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(content); err != nil {
		return nil, fmt.Errorf("gob-encoding content: %w", err)
	}
	return buf.Bytes(), nil
}
|
|
|
|
|
// var buf bytes.Buffer
|
|
|
|
|
// encoder := gob.NewEncoder(&buf)
|
|
|
|
|
// err := encoder.Encode(content)
|
|
|
|
|
// if err != nil {
|
|
|
|
|
// return nil, err
|
|
|
|
|
// }
|
|
|
|
|
// return buf.Bytes(), nil
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
// NewService creates a new service, will bypass etcd registration if no
|
|
|
|
|
// endpoints specified.
|
|
|
|
|
func NewService(idx int) (*Service, error) {
|
|
|
|
|
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
|
|
|
|
|
s := &Service{
|
|
|
|
|
idx: idx,
|
|
|
|
|
idx: idx,
|
|
|
|
|
checkpointInterval: time.Second * time.Duration(seconds),
|
|
|
|
|
checkpointPath: path,
|
|
|
|
|
client: client,
|
|
|
|
|
}
|
|
|
|
|
s.optMap = make(map[string]*optimizer)
|
|
|
|
|
s.initialized = make(chan struct{})
|
|
|
|
|
gob.Register(ParameterWithConfig{})
|
|
|
|
|
gob.Register(checkpoint{})
|
|
|
|
|
|
|
|
|
|
if cp != nil {
|
|
|
|
|
for _, item := range cp {
|
|
|
|
|
p := item.ParamConfig
|
|
|
|
|
st := item.State
|
|
|
|
|
s.optMap[p.Param.Name] = newOptimizer(p, st)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return s, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Save tells the parameter server to save parameters.
|
|
|
|
|
func (s *Service) Save(path string, dummy *int) error {
|
|
|
|
|
//FIXME: checkpoint is only used by pserver
|
|
|
|
|
// and has a constant path of */checkpoints/{pserver_idx}*
|
|
|
|
|
// pserver save checkpoint
|
|
|
|
|
func (s *Service) doCheckpoint() error {
|
|
|
|
|
<-s.initialized
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
var paramWithConfig ParameterWithConfig
|
|
|
|
|
|
|
|
|
|
cp := make([]parameterCheckPoint, 0, len(s.optMap))
|
|
|
|
|
index := 0
|
|
|
|
|
for name, opt := range s.optMap {
|
|
|
|
|
paramWithConfig.Param.Name = name
|
|
|
|
|
paramWithConfig.Param.ElementType = opt.elementType
|
|
|
|
|
paramWithConfig.Param.Content = opt.GetWeights()
|
|
|
|
|
paramWithConfig.State = opt.GetStates()
|
|
|
|
|
content, err := GetBytes(paramWithConfig)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
ck := checkpoint{}
|
|
|
|
|
h := md5.New()
|
|
|
|
|
ck.Md5sum = hex.EncodeToString(h.Sum(content))
|
|
|
|
|
ck.Timestamp = time.Now().String()
|
|
|
|
|
ck.Uuid = checkpoint_path + strconv.Itoa(s.idx)
|
|
|
|
|
ckbytes, err := GetBytes(ck)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
// TODO: according design doc, need to save Uuid to etcd in json format
|
|
|
|
|
// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
|
|
|
|
|
log.Infof("parameter checkpoint %s", ckbytes)
|
|
|
|
|
|
|
|
|
|
if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) {
|
|
|
|
|
log.Info("checkpoint not exists.")
|
|
|
|
|
} else {
|
|
|
|
|
err = os.Remove(ck.Uuid)
|
|
|
|
|
log.Infof("remove %s", ck.Uuid)
|
|
|
|
|
}
|
|
|
|
|
f, err := os.Create(ck.Uuid)
|
|
|
|
|
defer f.Close()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
writer := bufio.NewWriter(f)
|
|
|
|
|
_, err = writer.Write(content)
|
|
|
|
|
writer.Flush()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
var pc parameterCheckPoint
|
|
|
|
|
pc.ParamConfig.Param.Name = name
|
|
|
|
|
pc.ParamConfig.Param.ElementType = opt.elementType
|
|
|
|
|
pc.ParamConfig.Param.Content = opt.GetWeights()
|
|
|
|
|
pc.State = opt.GetStates()
|
|
|
|
|
cp[index] = pc
|
|
|
|
|
index++
|
|
|
|
|
}
|
|
|
|
|
var buf bytes.Buffer
|
|
|
|
|
encoder := gob.NewEncoder(&buf)
|
|
|
|
|
err := encoder.Encode(cp)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cpMeta := checkpointMeta{}
|
|
|
|
|
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
|
|
|
|
|
cpMeta.Timestamp = time.Now().String()
|
|
|
|
|
h := md5.New()
|
|
|
|
|
cpMeta.Md5sum = h.Sum(buf.Bytes())
|
|
|
|
|
|
|
|
|
|
cpMetajson, err := json.Marshal(cpMeta)
|
|
|
|
|
s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
|
|
|
|
|
log.Info("checkpoint does not exists.")
|
|
|
|
|
} else {
|
|
|
|
|
err = os.Remove(cpMeta.UUID)
|
|
|
|
|
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
|
|
|
|
|
}
|
|
|
|
|
f, err := os.Create(cpMeta.UUID)
|
|
|
|
|
defer f.Close()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
writer := bufio.NewWriter(f)
|
|
|
|
|
_, err = writer.Write(buf.Bytes())
|
|
|
|
|
writer.Flush()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|