You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/go/pserver/service.go

370 lines
8.2 KiB

// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path"
"strconv"
"sync"
"time"
uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus"
)
// ElementType tells how the raw bytes in a Parameter's Content are to be
// interpreted.
type ElementType int

// Supported element types. The values are assigned by iota in declaration
// order starting from zero, so entries must not be reordered.
// NOTE(review): presumably mirrored by the native optimizer's enum — confirm.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// ErrCheckpointNotFound indicates that no pserver checkpoint metadata is
// stored, i.e. there is nothing to recover from.
var ErrCheckpointNotFound = errors.New("checkpoint not found")

// Messages of the errors returned to RPC callers.
const (
	AlreadyInitialized = "pserver already initialized"
	Uninitialized      = "pserver not fully initialized"
	WrongChecksum      = "checkpoint file checksum validation failed"
)
// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	// Name identifies the parameter; the service keys its optimizers by it.
	Name string
	// ElementType describes the type of the elements stored in Content.
	ElementType ElementType
	// Content holds the parameter's raw element bytes.
	Content []byte
}

// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}
// checkpointMeta saves checkpoint metadata. It is stored JSON-encoded in
// etcd and read back to locate and validate the checkpoint file.
type checkpointMeta struct {
	UUID      string `json:"uuid"`      // random ID, also used as the checkpoint file name
	Path      string `json:"path"`      // location of the checkpoint file on disk
	MD5       string `json:"md5"`       // checksum string compared against the file content on load
	Timestamp int64  `json:"timestamp"` // when the metadata was created, in Unix nanoseconds
}

// Checkpoint is the pserver shard persisted in a file; it is gob-encoded
// into the checkpoint file.
type Checkpoint []parameterCheckpoint

// Gradient is the gradient of the parameter. It has the same layout as
// Parameter; Name selects the parameter the gradient applies to.
type Gradient Parameter
// Service is the RPC service for pserver.
type Service struct {
	// initialized is closed by FinishInitParams; whether it is closed gates
	// InitParam, SendGrad and GetParam.
	initialized chan struct{}
	// idx is this pserver's index; it is part of the etcd key under which
	// the checkpoint metadata is stored.
	idx int
	// checkpointInterval is the period between two checkpoints.
	checkpointInterval time.Duration
	// checkpointPath is the directory checkpoint files are written to.
	checkpointPath string
	// client is used to read and write the checkpoint metadata in etcd.
	client *EtcdClient
	// mu guards optMap.
	mu sync.Mutex
	// optMap maps parameter names to their optimizers.
	optMap map[string]*optimizer
}

// parameterCheckpoint saves parameter checkpoint: the parameter with its
// configuration plus the optimizer state.
type parameterCheckpoint struct {
	ParameterWithConfig
	State []byte // optimizer state as returned by the optimizer's GetStates
}
// loadMeta fetches the checkpoint metadata of pserver idx from etcd and
// decodes it from JSON. It returns ErrCheckpointNotFound when the stored
// value is empty, otherwise the etcd read or JSON decoding error, if any.
func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
	key := PsCheckpoint + strconv.Itoa(idx)
	raw, getErr := e.GetKey(key, 3*time.Second)
	switch {
	case getErr != nil:
		return meta, getErr
	case len(raw) == 0:
		return meta, ErrCheckpointNotFound
	}
	err = json.Unmarshal(raw, &meta)
	return meta, err
}
// LoadCheckpoint loads the checkpoint of pserver idx: it reads the checkpoint
// metadata from etcd, reads the checkpoint file the metadata points to,
// validates the file against the recorded checksum and gob-decodes it.
//
// It returns ErrCheckpointNotFound when no metadata is stored, and an error
// with message WrongChecksum when the file content does not match.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
	cpMeta, err := loadMeta(e, idx)
	if err != nil {
		return nil, err
	}

	content, err := ioutil.ReadFile(cpMeta.Path)
	if err != nil {
		return nil, err
	}

	// TODO(helin): change MD5 to CRC since CRC is better for file
	// checksum in our use case (emphasize speed over security).
	if !checksumMatches(content, cpMeta.MD5) {
		return nil, errors.New(WrongChecksum)
	}

	var cp Checkpoint
	if err = gob.NewDecoder(bytes.NewReader(content)).Decode(&cp); err != nil {
		return nil, err
	}
	return cp, nil
}

// checkpointChecksum returns the hex-encoded MD5 digest of content.
func checkpointChecksum(content []byte) string {
	sum := md5.Sum(content)
	return hex.EncodeToString(sum[:])
}

// checksumMatches reports whether want is the checksum of content.
//
// Besides the real MD5 digest it accepts the legacy value written by earlier
// versions. Those versions called md5.New().Sum(content), which appends the
// digest of the *empty* input to content instead of hashing it. So the legacy
// value is hex(content) followed by hex(md5("")). Accepting it keeps
// checkpoints saved before the fix restorable.
func checksumMatches(content []byte, want string) bool {
	if want == checkpointChecksum(content) {
		return true
	}
	legacyLen := len(content) + md5.Size
	if len(want) != hex.EncodedLen(legacyLen) {
		return false
	}
	emptySum := md5.Sum(nil)
	legacy := make([]byte, 0, legacyLen)
	legacy = append(legacy, content...)
	legacy = append(legacy, emptySum[:]...)
	return want == hex.EncodeToString(legacy)
}

// NewService creates a new pserver service.
//
// idx is the index of this pserver, interval the period between checkpoints,
// path the directory checkpoint files are written to, and client the etcd
// client used for checkpoint metadata. When cp is non-nil, an optimizer is
// recreated for every parameter in it, restoring the saved optimizer state.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
	s := &Service{
		idx:                idx,
		checkpointInterval: interval,
		checkpointPath:     path,
		client:             client,
		initialized:        make(chan struct{}),
		optMap:             make(map[string]*optimizer, len(cp)),
	}
	// Ranging over a nil Checkpoint is a no-op, so no nil check is needed.
	for _, item := range cp {
		p := ParameterWithConfig{
			Param:  item.Param,
			Config: item.Config,
		}
		s.optMap[p.Param.Name] = newOptimizer(p, item.State)
	}
	return s, nil
}

// InitParam initializes a parameter. It fails with AlreadyInitialized once
// FinishInitParams has been called.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Checked while holding s.mu; FinishInitParams closes the channel under
	// the same lock, so a parameter cannot slip in after initialization ended.
	select {
	case <-s.initialized:
		return errors.New(AlreadyInitialized)
	default:
	}

	// TODO(helin): parse parameter config

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
	return nil
}

// FinishInitParams tells the parameter server that the parameter
// initialization has finished. It unblocks GetParam, enables SendGrad and
// starts saving a checkpoint every checkpointInterval. Subsequent calls
// return an AlreadyInitialized error.
func (s *Service) FinishInitParams(_ int, _ *int) error {
	// Check-and-close must be atomic: without the lock, two concurrent calls
	// could both see the channel open and the second close would panic.
	s.mu.Lock()
	defer s.mu.Unlock()

	select {
	case <-s.initialized:
		return errors.New(AlreadyInitialized)
	default:
	}
	close(s.initialized)

	// time.Tick returns a nil channel for a non-positive duration, so such an
	// interval never produced a checkpoint; don't leave a goroutine blocked
	// forever on it.
	if s.checkpointInterval <= 0 {
		return nil
	}
	go func() {
		// The ticker lives as long as the service, so it is never stopped.
		for range time.Tick(s.checkpointInterval) {
			if err := s.checkpoint(); err != nil {
				log.Errorln(err)
			}
		}
	}()
	return nil
}

// SendGrad sends gradient to parameter servers for parameter
// optimization. It fails with Uninitialized before FinishInitParams.
func (s *Service) SendGrad(g Gradient, _ *int) error {
	select {
	case <-s.initialized:
	default:
		return errors.New(Uninitialized)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	o, ok := s.optMap[g.Name]
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", g.Name)
	}
	return o.UpdateParameter(g)
}

// GetParam gets parameters from the parameter server. It blocks until
// FinishInitParams has been called.
func (s *Service) GetParam(name string, parameter *Parameter) error {
	<-s.initialized
	s.mu.Lock()
	defer s.mu.Unlock()

	opt, ok := s.optMap[name]
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
	}

	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// parameter content.
	parameter.Name = name
	parameter.ElementType = opt.elementType
	parameter.Content = opt.GetWeights()
	return nil
}

// traceTime logs how long the operation called name took since start.
func traceTime(start time.Time, name string) {
	elapsed := time.Since(start)
	log.Infof("%s took %v", name, elapsed)
}

// checkpoint saves checkpoint to disk and publishes its metadata to etcd.
// After the new metadata is published, the previous checkpoint file is
// removed (best effort).
//
// checkpoint should be only called after the parameters are
// initialized.
func (s *Service) checkpoint() (err error) {
	log.Infoln("Begin save checkpoint.")
	defer traceTime(time.Now(), "save checkpoint")

	// TODO(helin): write checkpoint incrementally to reduce memory
	// footprint during checkpoint.
	s.mu.Lock()
	cp := make([]parameterCheckpoint, 0, len(s.optMap))
	for name, opt := range s.optMap {
		var pc parameterCheckpoint
		pc.Param.Name = name
		pc.Param.ElementType = opt.elementType
		pc.Param.Content = opt.GetWeights()
		pc.Config = opt.config
		pc.State = opt.GetStates()
		cp = append(cp, pc)
	}
	s.mu.Unlock()

	var buf bytes.Buffer
	if err = gob.NewEncoder(&buf).Encode(cp); err != nil {
		return err
	}
	content := buf.Bytes()

	// Look up the previous checkpoint before writing anything, so a failed
	// lookup does not leave an orphaned file behind.
	oldMeta, err := loadMeta(s.client, s.idx)
	if err == ErrCheckpointNotFound {
		log.Infoln("Do not have existing checkpoint.")
		err = nil
	}
	if err != nil {
		return err
	}

	id := uuid.NewV4().String()
	p := path.Join(s.checkpointPath, id)
	if err = writeCheckpointFile(p, content); err != nil {
		return err
	}

	metaJSON, err := json.Marshal(checkpointMeta{
		UUID:      id,
		Timestamp: time.Now().UnixNano(),
		MD5:       checkpointChecksum(content),
		Path:      p,
	})
	if err == nil {
		err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), metaJSON, 3*time.Second, false)
	}
	if err != nil {
		// The new file was never published; remove it so failed attempts
		// do not accumulate on disk.
		if rmErr := os.Remove(p); rmErr != nil {
			log.Errorln(rmErr)
		}
		return err
	}

	if oldMeta.Path != "" && oldMeta.Path != p {
		if rmErr := os.Remove(oldMeta.Path); rmErr != nil {
			// log error, but still treat checkpoint as
			// successful.
			log.Errorln(rmErr)
		}
	}
	return nil
}

// writeCheckpointFile writes data to a newly created file at p and syncs it
// to stable storage. On any failure the partially written file is removed.
func writeCheckpointFile(p string, data []byte) (err error) {
	f, err := os.Create(p)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := f.Close(); closeErr != nil {
			if err == nil {
				err = closeErr
			} else {
				log.Errorln(closeErr)
			}
		}
		if err != nil {
			if rmErr := os.Remove(p); rmErr != nil {
				log.Errorln(rmErr)
			}
		}
	}()

	w := bufio.NewWriter(f)
	if _, err = w.Write(data); err != nil {
		return err
	}
	if err = w.Flush(); err != nil {
		return err
	}
	// Make the file durable before its path is published to etcd, so a crash
	// cannot leave the metadata pointing at a truncated checkpoint.
	return f.Sync()
}