|
|
|
@ -5,6 +5,7 @@ import (
|
|
|
|
|
"bytes"
|
|
|
|
|
"crypto/md5"
|
|
|
|
|
"encoding/gob"
|
|
|
|
|
"encoding/hex"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
|
|
|
@ -67,30 +68,19 @@ type checkpointMeta struct {
|
|
|
|
|
type Checkpoint []parameterCheckPoint
|
|
|
|
|
|
|
|
|
|
// Gradient is the gradient of the parameter.
|
|
|
|
|
type Gradient Parameter
|
|
|
|
|
|
|
|
|
|
// Service is the RPC service for pserver.
|
|
|
|
|
type Service struct {
|
|
|
|
|
initialized chan struct{}
|
|
|
|
|
idx int
|
|
|
|
|
checkpointInterval int
|
|
|
|
|
checkpointInterval time.Duration
|
|
|
|
|
checkpointPath string
|
|
|
|
|
client *EtcdClient
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
optMap map[string]*optimizer
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// //serialize ParameterWithConfig to byte stream
|
|
|
|
|
// func GetBytes(content ...interface{}) ([]byte, error) {
|
|
|
|
|
|
|
|
|
|
// var buf bytes.Buffer
|
|
|
|
|
// encoder := gob.NewEncoder(&buf)
|
|
|
|
|
// err := encoder.Encode(content)
|
|
|
|
|
// if err != nil {
|
|
|
|
|
// return nil, err
|
|
|
|
|
// }
|
|
|
|
|
// return buf.Bytes(), nil
|
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
// NewService creates a new service, will bypass etcd registration if no
|
|
|
|
|
// endpoints specified.
|
|
|
|
|
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
|
|
|
|
@ -129,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
|
|
|
|
|
// TODO(helin): check if paramWithConfigs.Param.Content is
|
|
|
|
|
// properly memory aligned, if not, make copy to a memory
|
|
|
|
|
// aligned region.
|
|
|
|
|
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
|
|
|
|
|
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -218,7 +208,7 @@ func (s *Service) doCheckpoint() error {
|
|
|
|
|
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
|
|
|
|
|
cpMeta.Timestamp = time.Now().String()
|
|
|
|
|
h := md5.New()
|
|
|
|
|
cpMeta.Md5sum = h.Sum(buf.Bytes())
|
|
|
|
|
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
|
|
|
|
|
|
|
|
|
|
cpMetajson, err := json.Marshal(cpMeta)
|
|
|
|
|
s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
|
|
|
|
|