|
|
|
@ -9,6 +9,7 @@ import (
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"os"
|
|
|
|
|
"path/filepath"
|
|
|
|
|
"strconv"
|
|
|
|
@ -21,14 +22,14 @@ import (
|
|
|
|
|
// ElementType is the type of elements of a Parameter.
|
|
|
|
|
type ElementType int
|
|
|
|
|
|
|
|
|
|
// RPC error message.
|
|
|
|
|
const (
|
|
|
|
|
// AlreadyInitialized is true if pserver is initialized
|
|
|
|
|
AlreadyInitialized = "pserver already initialized"
|
|
|
|
|
// Uninitialized is true if pserver not fully initialized
|
|
|
|
|
Uninitialized = "pserver not fully initialized"
|
|
|
|
|
CheckpointMD5Failed = "checkpoint file MD5 validation failed"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// Supported element types
|
|
|
|
|
// Supported element types.
|
|
|
|
|
const (
|
|
|
|
|
Int32 ElementType = iota
|
|
|
|
|
UInt32
|
|
|
|
@ -51,21 +52,15 @@ type ParameterWithConfig struct {
|
|
|
|
|
Config []byte // parameter configuration in Proto Buffer format
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ParameterCheckpoint is Parameter and State checkpoint
|
|
|
|
|
type ParameterCheckpoint struct {
|
|
|
|
|
ParamConfig ParameterWithConfig
|
|
|
|
|
State []byte
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// checkpoint signature
|
|
|
|
|
// checkpointMeta saves checkpoint metadata
|
|
|
|
|
type checkpointMeta struct {
|
|
|
|
|
UUID string `json:"uuid"`
|
|
|
|
|
Md5sum string `json:"md5sum"`
|
|
|
|
|
Timestamp string `json:"timestamp"`
|
|
|
|
|
MD5 string `json:"md5"`
|
|
|
|
|
Timestamp int64 `json:"timestamp"`
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Checkpoint is the pserver shard persist in file
|
|
|
|
|
type Checkpoint []ParameterCheckpoint
|
|
|
|
|
type Checkpoint []parameterCheckpoint
|
|
|
|
|
|
|
|
|
|
// Gradient is the gradient of the parameter.
|
|
|
|
|
type Gradient Parameter
|
|
|
|
@ -81,12 +76,53 @@ type Service struct {
|
|
|
|
|
optMap map[string]*optimizer
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// parameterCheckpoint saves parameter checkpoint
|
|
|
|
|
type parameterCheckpoint struct {
|
|
|
|
|
ParameterWithConfig
|
|
|
|
|
State []byte
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewCheckpointFromFile loads parameters and state from checkpoint file
|
|
|
|
|
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) {
|
|
|
|
|
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var cpMeta checkpointMeta
|
|
|
|
|
if err = json.Unmarshal(v, &cpMeta); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn := filepath.Join(cpPath, cpMeta.UUID)
|
|
|
|
|
if _, err = os.Stat(fn); os.IsNotExist(err) {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
content, err := ioutil.ReadFile(fn)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
h := md5.New()
|
|
|
|
|
md5 := hex.EncodeToString(h.Sum(content))
|
|
|
|
|
if md5 != cpMeta.MD5 {
|
|
|
|
|
return nil, errors.New(CheckpointMD5Failed)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dec := gob.NewDecoder(bytes.NewReader(content))
|
|
|
|
|
cp := &Checkpoint{}
|
|
|
|
|
if err = dec.Decode(cp); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return cp, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewService creates a new service, will bypass etcd registration if no
|
|
|
|
|
// endpoints specified.
|
|
|
|
|
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
|
|
|
|
|
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
|
|
|
|
|
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) {
|
|
|
|
|
s := &Service{
|
|
|
|
|
idx: idx,
|
|
|
|
|
checkpointInterval: time.Second * time.Duration(seconds),
|
|
|
|
|
checkpointInterval: interval,
|
|
|
|
|
checkpointPath: path,
|
|
|
|
|
client: client,
|
|
|
|
|
}
|
|
|
|
@ -94,10 +130,12 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp
|
|
|
|
|
s.initialized = make(chan struct{})
|
|
|
|
|
|
|
|
|
|
if cp != nil {
|
|
|
|
|
for _, item := range cp {
|
|
|
|
|
p := item.ParamConfig
|
|
|
|
|
st := item.State
|
|
|
|
|
s.optMap[p.Param.Name] = newOptimizer(p, st)
|
|
|
|
|
for _, item := range *cp {
|
|
|
|
|
p := ParameterWithConfig{
|
|
|
|
|
Param: item.Param,
|
|
|
|
|
Config: item.Config,
|
|
|
|
|
}
|
|
|
|
|
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return s, nil
|
|
|
|
@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error {
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
cp := make([]ParameterCheckpoint, 0, len(s.optMap))
|
|
|
|
|
cp := make([]parameterCheckpoint, len(s.optMap))
|
|
|
|
|
index := 0
|
|
|
|
|
for name, opt := range s.optMap {
|
|
|
|
|
var pc ParameterCheckpoint
|
|
|
|
|
pc.ParamConfig.Param.Name = name
|
|
|
|
|
pc.ParamConfig.Param.ElementType = opt.elementType
|
|
|
|
|
pc.ParamConfig.Param.Content = opt.GetWeights()
|
|
|
|
|
var pc parameterCheckpoint
|
|
|
|
|
pc.Param.Name = name
|
|
|
|
|
pc.Param.ElementType = opt.elementType
|
|
|
|
|
pc.Param.Content = opt.GetWeights()
|
|
|
|
|
pc.State = opt.GetStates()
|
|
|
|
|
cp[index] = pc
|
|
|
|
|
index++
|
|
|
|
@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error {
|
|
|
|
|
|
|
|
|
|
cpMeta := checkpointMeta{}
|
|
|
|
|
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
|
|
|
|
|
cpMeta.Timestamp = time.Now().String()
|
|
|
|
|
cpMeta.Timestamp = time.Now().UnixNano()
|
|
|
|
|
h := md5.New()
|
|
|
|
|
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
|
|
|
|
|
cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
|
|
|
|
|
|
|
|
|
|
cpMetajson, _ := json.Marshal(cpMeta)
|
|
|
|
|
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
|
|
|
|
|
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
@ -219,8 +257,12 @@ func (s *Service) doCheckpoint() error {
|
|
|
|
|
log.Info("checkpoint does not exists.")
|
|
|
|
|
} else {
|
|
|
|
|
err = os.Remove(cpMeta.UUID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
|
|
|
|
|
} else {
|
|
|
|
|
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
f, err := os.Create(cpMeta.UUID)
|
|
|
|
|
defer f.Close()
|
|
|
|
|
if err != nil {
|
|
|
|
|