|  |  |  | @ -5,10 +5,11 @@ import ( | 
			
		
	
		
			
				
					|  |  |  |  | 	"bytes" | 
			
		
	
		
			
				
					|  |  |  |  | 	"crypto/md5" | 
			
		
	
		
			
				
					|  |  |  |  | 	"encoding/gob" | 
			
		
	
		
			
				
					|  |  |  |  | 	"encoding/hex" | 
			
		
	
		
			
				
					|  |  |  |  | 	"encoding/json" | 
			
		
	
		
			
				
					|  |  |  |  | 	"errors" | 
			
		
	
		
			
				
					|  |  |  |  | 	"fmt" | 
			
		
	
		
			
				
					|  |  |  |  | 	"os" | 
			
		
	
		
			
				
					|  |  |  |  | 	"path/filepath" | 
			
		
	
		
			
				
					|  |  |  |  | 	"strconv" | 
			
		
	
		
			
				
					|  |  |  |  | 	"sync" | 
			
		
	
		
			
				
					|  |  |  |  | 	"time" | 
			
		
	
	
		
			
				
					|  |  |  | @ -26,10 +27,6 @@ const ( | 
			
		
	
		
			
				
					|  |  |  |  | 	Uninitialized = "pserver not fully initialized" | 
			
		
	
		
			
				
					|  |  |  |  | ) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					// checkpoint_path is the path prefix for checkpoint files; the pserver
					// index is appended to it to form the name of the file to write.
					const (
						checkpoint_path = "./checkpoints/"
					)
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // Supported element types
 | 
			
		
	
		
			
				
					|  |  |  |  | const ( | 
			
		
	
		
			
				
					|  |  |  |  | 	Int32 ElementType = iota | 
			
		
	
	
		
			
				
					|  |  |  | @ -51,49 +48,68 @@ type Parameter struct { | 
			
		
	
		
			
				
					|  |  |  |  | type ParameterWithConfig struct { | 
			
		
	
		
			
				
					|  |  |  |  | 	Param  Parameter | 
			
		
	
		
			
				
					|  |  |  |  | 	Config []byte // parameter configuration in Proto Buffer format
 | 
			
		
	
		
			
				
					|  |  |  |  | 	State  []byte // parameter training state
 | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // Checkpoint of Parameter and State
 | 
			
		
	
		
			
				
					|  |  |  |  | type parameterCheckPoint struct { | 
			
		
	
		
			
				
					|  |  |  |  | 	ParamConfig ParameterWithConfig | 
			
		
	
		
			
				
					|  |  |  |  | 	State       []byte | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					// checkpointMeta is the signature of a saved checkpoint. It is marshalled
					// to JSON and handed to the etcd client once a checkpoint has been taken.
					type checkpointMeta struct {
						// UUID identifies the checkpoint; it is set to the path of the
						// checkpoint file (checkpoint path followed by the pserver index).
						UUID string `json:"uuid"`
						// Md5sum is the checksum of the encoded checkpoint content.
						Md5sum string `json:"md5sum"`
						// Timestamp is the time the checkpoint was taken, formatted with
						// time.Time.String.
						Timestamp string `json:"timestamp"`
					}
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // Checkpoint is the pserver shard persist in file
 | 
			
		
	
		
			
				
					|  |  |  |  | type Checkpoint []parameterCheckPoint | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // Gradient is the gradient of the parameter.
 | 
			
		
	
		
			
				
					|  |  |  |  | type Gradient Parameter | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // Service is the RPC service for pserver.
 | 
			
		
	
		
			
				
					|  |  |  |  | type Service struct { | 
			
		
	
		
			
				
					|  |  |  |  | 	initialized        chan struct{} | 
			
		
	
		
			
				
					|  |  |  |  | 	idx                int | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 	checkpointInterval int | 
			
		
	
		
			
				
					|  |  |  |  | 	checkpointPath     string | 
			
		
	
		
			
				
					|  |  |  |  | 	client             *EtcdClient | 
			
		
	
		
			
				
					|  |  |  |  | 	mu                 sync.Mutex | 
			
		
	
		
			
				
					|  |  |  |  | 	optMap             map[string]*optimizer | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					// checkpoint describes a saved checkpoint: where it was written, the
					// checksum of its content and when it was taken.
					type checkpoint struct {
						Uuid      string // path of the checkpoint file
						Md5sum    string // hex-encoded checksum of the checkpoint content
						Timestamp string // time the checkpoint was taken
					}
			
		
	
		
			
				
					// GetBytes serializes content to a byte stream using encoding/gob.
					//
					// All arguments are encoded together as one []interface{} value, so the
					// result has to be decoded into a []interface{}. Concrete types other than
					// gob's built-in basic types must be registered with gob.Register before
					// they are passed in.
					//
					// It returns the encoded bytes, or a nil slice and the encoder's error.
					func GetBytes(content ...interface{}) ([]byte, error) {
						var buf bytes.Buffer
						if err := gob.NewEncoder(&buf).Encode(content); err != nil {
							return nil, err
						}
						return buf.Bytes(), nil
					}
 | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // NewService creates a new service, will bypass etcd registration if no
 | 
			
		
	
		
			
				
					|  |  |  |  | // endpoints specified.
 | 
			
		
	
		
			
				
					|  |  |  |  | func NewService(idx int) (*Service, error) { | 
			
		
	
		
			
				
					|  |  |  |  | func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { | 
			
		
	
		
			
				
					|  |  |  |  | 	s := &Service{ | 
			
		
	
		
			
				
					|  |  |  |  | 		idx:                idx, | 
			
		
	
		
			
				
					|  |  |  |  | 		checkpointInterval: time.Second * time.Duration(seconds), | 
			
		
	
		
			
				
					|  |  |  |  | 		checkpointPath:     path, | 
			
		
	
		
			
				
					|  |  |  |  | 		client:             client, | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 	s.optMap = make(map[string]*optimizer) | 
			
		
	
		
			
				
					|  |  |  |  | 	s.initialized = make(chan struct{}) | 
			
		
	
		
			
				
					|  |  |  |  | 	gob.Register(ParameterWithConfig{}) | 
			
		
	
		
			
				
					|  |  |  |  | 	gob.Register(checkpoint{}) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 	if cp != nil { | 
			
		
	
		
			
				
					|  |  |  |  | 		for _, item := range cp { | 
			
		
	
		
			
				
					|  |  |  |  | 			p := item.ParamConfig | 
			
		
	
		
			
				
					|  |  |  |  | 			st := item.State | 
			
		
	
		
			
				
					|  |  |  |  | 			s.optMap[p.Param.Name] = newOptimizer(p, st) | 
			
		
	
		
			
				
					|  |  |  |  | 		} | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 	return s, nil | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { | 
			
		
	
		
			
				
					|  |  |  |  | 	return nil | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // Save tells the parameter server to save parameters.
 | 
			
		
	
		
			
				
					|  |  |  |  | func (s *Service) Save(path string, dummy *int) error { | 
			
		
	
		
			
				
					|  |  |  |  | 	//FIXME: checkpoint is only used by pserver
 | 
			
		
	
		
			
				
					|  |  |  |  | 	// and has a constant path of */checkpoints/{pserver_idx}*
 | 
			
		
	
		
			
				
					|  |  |  |  | // pserver save checkpoint
 | 
			
		
	
		
			
				
					|  |  |  |  | func (s *Service) doCheckpoint() error { | 
			
		
	
		
			
				
					|  |  |  |  | 	<-s.initialized | 
			
		
	
		
			
				
					|  |  |  |  | 	s.mu.Lock() | 
			
		
	
		
			
				
					|  |  |  |  | 	defer s.mu.Unlock() | 
			
		
	
		
			
				
					|  |  |  |  | 	var paramWithConfig ParameterWithConfig | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 	cp := make([]parameterCheckPoint, 0, len(s.optMap)) | 
			
		
	
		
			
				
					|  |  |  |  | 	index := 0 | 
			
		
	
		
			
				
					|  |  |  |  | 	for name, opt := range s.optMap { | 
			
		
	
		
			
				
					|  |  |  |  | 		paramWithConfig.Param.Name = name | 
			
		
	
		
			
				
					|  |  |  |  | 		paramWithConfig.Param.ElementType = opt.elementType | 
			
		
	
		
			
				
					|  |  |  |  | 		paramWithConfig.Param.Content = opt.GetWeights() | 
			
		
	
		
			
				
					|  |  |  |  | 		paramWithConfig.State = opt.GetStates() | 
			
		
	
		
			
				
					|  |  |  |  | 		content, err := GetBytes(paramWithConfig) | 
			
		
	
		
			
				
					|  |  |  |  | 		var pc parameterCheckPoint | 
			
		
	
		
			
				
					|  |  |  |  | 		pc.ParamConfig.Param.Name = name | 
			
		
	
		
			
				
					|  |  |  |  | 		pc.ParamConfig.Param.ElementType = opt.elementType | 
			
		
	
		
			
				
					|  |  |  |  | 		pc.ParamConfig.Param.Content = opt.GetWeights() | 
			
		
	
		
			
				
					|  |  |  |  | 		pc.State = opt.GetStates() | 
			
		
	
		
			
				
					|  |  |  |  | 		cp[index] = pc | 
			
		
	
		
			
				
					|  |  |  |  | 		index++ | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 	var buf bytes.Buffer | 
			
		
	
		
			
				
					|  |  |  |  | 	encoder := gob.NewEncoder(&buf) | 
			
		
	
		
			
				
					|  |  |  |  | 	err := encoder.Encode(cp) | 
			
		
	
		
			
				
					|  |  |  |  | 	if err != nil { | 
			
		
	
		
			
				
					|  |  |  |  | 			log.Errorln(err) | 
			
		
	
		
			
				
					|  |  |  |  | 		return err | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 		ck := checkpoint{} | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 	cpMeta := checkpointMeta{} | 
			
		
	
		
			
				
					|  |  |  |  | 	cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) | 
			
		
	
		
			
				
					|  |  |  |  | 	cpMeta.Timestamp = time.Now().String() | 
			
		
	
		
			
				
					|  |  |  |  | 	h := md5.New() | 
			
		
	
		
			
				
					|  |  |  |  | 		ck.Md5sum = hex.EncodeToString(h.Sum(content)) | 
			
		
	
		
			
				
					|  |  |  |  | 		ck.Timestamp = time.Now().String() | 
			
		
	
		
			
				
					|  |  |  |  | 		ck.Uuid = checkpoint_path + strconv.Itoa(s.idx) | 
			
		
	
		
			
				
					|  |  |  |  | 		ckbytes, err := GetBytes(ck) | 
			
		
	
		
			
				
					|  |  |  |  | 	cpMeta.Md5sum = h.Sum(buf.Bytes()) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 	cpMetajson, err := json.Marshal(cpMeta) | 
			
		
	
		
			
				
					|  |  |  |  | 	s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) | 
			
		
	
		
			
				
					|  |  |  |  | 	if err != nil { | 
			
		
	
		
			
				
					|  |  |  |  | 			log.Errorln(err) | 
			
		
	
		
			
				
					|  |  |  |  | 		return err | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 		// TODO: according design doc, need to save Uuid to etcd in json format
 | 
			
		
	
		
			
				
					|  |  |  |  | 		// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
 | 
			
		
	
		
			
				
					|  |  |  |  | 		log.Infof("parameter checkpoint %s", ckbytes) | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | 		if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) { | 
			
		
	
		
			
				
					|  |  |  |  | 			log.Info("checkpoint not exists.") | 
			
		
	
		
			
				
					|  |  |  |  | 	if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) { | 
			
		
	
		
			
				
					|  |  |  |  | 		log.Info("checkpoint does not exists.") | 
			
		
	
		
			
				
					|  |  |  |  | 	} else { | 
			
		
	
		
			
				
					|  |  |  |  | 			err = os.Remove(ck.Uuid) | 
			
		
	
		
			
				
					|  |  |  |  | 			log.Infof("remove %s", ck.Uuid) | 
			
		
	
		
			
				
					|  |  |  |  | 		err = os.Remove(cpMeta.UUID) | 
			
		
	
		
			
				
					|  |  |  |  | 		log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 		f, err := os.Create(ck.Uuid) | 
			
		
	
		
			
				
					|  |  |  |  | 	f, err := os.Create(cpMeta.UUID) | 
			
		
	
		
			
				
					|  |  |  |  | 	defer f.Close() | 
			
		
	
		
			
				
					|  |  |  |  | 	if err != nil { | 
			
		
	
		
			
				
					|  |  |  |  | 		log.Errorln(err) | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 	writer := bufio.NewWriter(f) | 
			
		
	
		
			
				
					|  |  |  |  | 		_, err = writer.Write(content) | 
			
		
	
		
			
				
					|  |  |  |  | 	_, err = writer.Write(buf.Bytes()) | 
			
		
	
		
			
				
					|  |  |  |  | 	writer.Flush() | 
			
		
	
		
			
				
					|  |  |  |  | 	if err != nil { | 
			
		
	
		
			
				
					|  |  |  |  | 		log.Errorln(err) | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 	} | 
			
		
	
		
			
				
					|  |  |  |  | 	return nil | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
	
		
			
				
					|  |  |  | 
 |