|
|
|
@ -25,11 +25,13 @@ import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"os"
|
|
|
|
|
"path/filepath"
|
|
|
|
|
"path"
|
|
|
|
|
"strconv"
|
|
|
|
|
"sync"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
uuid "github.com/satori/go.uuid"
|
|
|
|
|
|
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -44,7 +46,7 @@ var ErrCheckpointNotFound = errors.New("checkpoint not found")
|
|
|
|
|
// Messages carried by errors that the pserver service reports.
const (
	// AlreadyInitialized is the message for a pserver that has
	// already been initialized.
	AlreadyInitialized = "pserver already initialized"

	// Uninitialized is the message for a pserver that has not been
	// fully initialized yet.
	Uninitialized = "pserver not fully initialized"

	// CheckpointMD5Failed is the message for a checkpoint file whose
	// MD5 does not match the value recorded in its metadata.
	CheckpointMD5Failed = "checkpoint file MD5 validation failed"

	// WrongChecksum is the message for a checkpoint file whose
	// checksum does not match the value recorded in its metadata.
	WrongChecksum = "checkpoint file checksum validation failed"
)
|
|
|
|
|
|
|
|
|
|
// Supported element types.
|
|
|
|
@ -73,11 +75,12 @@ type ParameterWithConfig struct {
|
|
|
|
|
// checkpointMeta is the metadata record describing one saved
// checkpoint. It is encoded as JSON when stored and decoded from JSON
// when loaded.
type checkpointMeta struct {
	// UUID is the unique identifier generated for the checkpoint.
	UUID string `json:"uuid"`
	// Path is the location of the checkpoint file on disk.
	Path string `json:"path"`
	// MD5 is the hex-encoded checksum string of the checkpoint
	// content; it is compared against the file content on load.
	MD5 string `json:"md5"`
	// Timestamp is the save time in nanoseconds since the Unix epoch.
	Timestamp int64 `json:"timestamp"`
}
|
|
|
|
|
|
|
|
|
|
// Checkpoint is the pserver shard persist in file
|
|
|
|
|
// Checkpoint is the pserver shard persisted to a file.
|
|
|
|
|
type Checkpoint []parameterCheckpoint
|
|
|
|
|
|
|
|
|
|
// Gradient is the gradient of the parameter.
|
|
|
|
@ -90,50 +93,58 @@ type Service struct {
|
|
|
|
|
checkpointInterval time.Duration
|
|
|
|
|
checkpointPath string
|
|
|
|
|
client *EtcdClient
|
|
|
|
|
|
|
|
|
|
mu sync.Mutex
|
|
|
|
|
optMap map[string]*optimizer
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// parameterCheckpoint saves parameter checkpoint
|
|
|
|
|
// parameterCheckpoint saves parameter checkpoint.
|
|
|
|
|
type parameterCheckpoint struct {
|
|
|
|
|
ParameterWithConfig
|
|
|
|
|
State []byte
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NewCheckpointFromFile loads parameters and state from checkpoint file
|
|
|
|
|
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
|
|
|
|
|
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
|
|
|
|
|
func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
|
|
|
|
|
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(v) == 0 {
|
|
|
|
|
return nil, ErrCheckpointNotFound
|
|
|
|
|
err = ErrCheckpointNotFound
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var cpMeta checkpointMeta
|
|
|
|
|
if err = json.Unmarshal(v, &cpMeta); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
if err = json.Unmarshal(v, &meta); err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn := filepath.Join(cpPath, cpMeta.UUID)
|
|
|
|
|
if _, err = os.Stat(fn); os.IsNotExist(err) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// LoadCheckpoint loads checkpoint from file.
|
|
|
|
|
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
|
|
|
|
|
cpMeta, err := loadMeta(e, idx)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
content, err := ioutil.ReadFile(fn)
|
|
|
|
|
|
|
|
|
|
content, err := ioutil.ReadFile(cpMeta.Path)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(helin): change MD5 to CRC since CRC is better for file
|
|
|
|
|
// checksum in our use case (emphasize speed over security).
|
|
|
|
|
h := md5.New()
|
|
|
|
|
md5 := hex.EncodeToString(h.Sum(content))
|
|
|
|
|
if md5 != cpMeta.MD5 {
|
|
|
|
|
return nil, errors.New(CheckpointMD5Failed)
|
|
|
|
|
return nil, errors.New(WrongChecksum)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
dec := gob.NewDecoder(bytes.NewReader(content))
|
|
|
|
|
cp := Checkpoint{}
|
|
|
|
|
if err = dec.Decode(cp); err != nil {
|
|
|
|
|
var cp Checkpoint
|
|
|
|
|
if err = dec.Decode(&cp); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return cp, nil
|
|
|
|
@ -193,6 +204,15 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
close(s.initialized)
|
|
|
|
|
go func() {
|
|
|
|
|
t := time.Tick(s.checkpointInterval)
|
|
|
|
|
for range t {
|
|
|
|
|
err := s.checkpoint()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Errorln(err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -240,23 +260,36 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// pserver save checkpoint
|
|
|
|
|
func (s *Service) doCheckpoint() (err error) {
|
|
|
|
|
<-s.initialized
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
func traceTime(start time.Time, name string) {
|
|
|
|
|
elapsed := time.Since(start)
|
|
|
|
|
log.Infof("%s took %v", name, elapsed)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// checkpoint saves checkpoint to disk.
|
|
|
|
|
//
|
|
|
|
|
// checkpoint should be only called after the parameters are
|
|
|
|
|
// initialized.
|
|
|
|
|
func (s *Service) checkpoint() (err error) {
|
|
|
|
|
log.Infoln("Begin save checkpoint.")
|
|
|
|
|
defer traceTime(time.Now(), "save checkpoint")
|
|
|
|
|
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
cp := make([]parameterCheckpoint, len(s.optMap))
|
|
|
|
|
index := 0
|
|
|
|
|
// TODO(helin): write checkpoint incrementally to reduce memory
|
|
|
|
|
// footprint during checkpoint.
|
|
|
|
|
for name, opt := range s.optMap {
|
|
|
|
|
var pc parameterCheckpoint
|
|
|
|
|
pc.Param.Name = name
|
|
|
|
|
pc.Param.ElementType = opt.elementType
|
|
|
|
|
pc.Param.Content = opt.GetWeights()
|
|
|
|
|
pc.Config = opt.config
|
|
|
|
|
pc.State = opt.GetStates()
|
|
|
|
|
cp[index] = pc
|
|
|
|
|
index++
|
|
|
|
|
}
|
|
|
|
|
s.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
var buf bytes.Buffer
|
|
|
|
|
encoder := gob.NewEncoder(&buf)
|
|
|
|
|
err = encoder.Encode(cp)
|
|
|
|
@ -264,32 +297,9 @@ func (s *Service) doCheckpoint() (err error) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cpMeta := checkpointMeta{}
|
|
|
|
|
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
|
|
|
|
|
cpMeta.Timestamp = time.Now().UnixNano()
|
|
|
|
|
h := md5.New()
|
|
|
|
|
cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
|
|
|
|
|
|
|
|
|
|
cpMetajson, err := json.Marshal(cpMeta)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
|
|
|
|
|
log.Info("checkpoint does not exists.")
|
|
|
|
|
} else {
|
|
|
|
|
err = os.Remove(cpMeta.UUID)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
|
|
|
|
|
} else {
|
|
|
|
|
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
f, err := os.Create(cpMeta.UUID)
|
|
|
|
|
id := uuid.NewV4().String()
|
|
|
|
|
p := path.Join(s.checkpointPath, id)
|
|
|
|
|
f, err := os.Create(p)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
@ -317,5 +327,43 @@ func (s *Service) doCheckpoint() (err error) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
oldMeta, err := loadMeta(s.client, s.idx)
|
|
|
|
|
if err == ErrCheckpointNotFound {
|
|
|
|
|
log.Infoln("Do not have existing checkpoint.")
|
|
|
|
|
err = nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
h := md5.New()
|
|
|
|
|
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
|
|
|
|
|
cpMeta := checkpointMeta{
|
|
|
|
|
UUID: id,
|
|
|
|
|
Timestamp: time.Now().UnixNano(),
|
|
|
|
|
MD5: md5,
|
|
|
|
|
Path: p,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
json, err := json.Marshal(cpMeta)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if oldMeta.Path != "" {
|
|
|
|
|
rmErr := os.Remove(oldMeta.Path)
|
|
|
|
|
if rmErr != nil {
|
|
|
|
|
// log error, but still treat checkpoint as
|
|
|
|
|
// successful.
|
|
|
|
|
log.Errorln(rmErr)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|