parent
ef038743f1
commit
1777017a05
@ -1,3 +0,0 @@
|
||||
vendor/
|
||||
.glide/
|
||||
proto/*.go
|
@ -1,23 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
add_subdirectory(pserver/client/c)
|
||||
add_subdirectory(cmd/pserver)
|
||||
add_subdirectory(cmd/master)
|
||||
add_subdirectory(master/c)
|
||||
add_subdirectory(master)
|
||||
add_subdirectory(pserver)
|
||||
add_subdirectory(pserver/client)
|
||||
add_subdirectory(utils/networkhelper)
|
@ -1,15 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
go_binary(master SRC master.go)
|
@ -1,120 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/inconshreveable/log15"
|
||||
"github.com/namsral/flag"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.Int("port", 8080, "port of the master server.")
|
||||
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
|
||||
endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
|
||||
taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.")
|
||||
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
|
||||
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
|
||||
logLevel := flag.String("log-level", "info",
|
||||
"log level, possible values: debug, info, warn, error, crit")
|
||||
flag.Parse()
|
||||
|
||||
lvl, err := log.LvlFromString(*logLevel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
log.Root().SetHandler(
|
||||
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
|
||||
)
|
||||
|
||||
if *endpoints == "" {
|
||||
log.Warn("-endpoints not set, fault tolerance not be enabled.")
|
||||
}
|
||||
|
||||
var store master.Store
|
||||
if *endpoints != "" {
|
||||
eps := strings.Split(*endpoints, ",")
|
||||
ip, err := networkhelper.GetExternalIP()
|
||||
if err != nil {
|
||||
log.Crit("get external ip error", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", ip, *port)
|
||||
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
|
||||
if err != nil {
|
||||
log.Crit("error creating etcd client.", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
store = &master.InMemStore{}
|
||||
}
|
||||
|
||||
shutdown := func() {
|
||||
log.Info("shutting down gracefully")
|
||||
err := store.Shutdown()
|
||||
if err != nil {
|
||||
log.Error("shutdown error", log.Ctx{"error": err})
|
||||
}
|
||||
}
|
||||
|
||||
// Guaranteed to run even panic happens.
|
||||
defer shutdown()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
|
||||
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
|
||||
if err != nil {
|
||||
log.Crit("error creating new service.", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = rpc.Register(s)
|
||||
if err != nil {
|
||||
log.Crit("error registering to etcd.", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rpc.HandleHTTP()
|
||||
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
|
||||
if err != nil {
|
||||
log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err = http.Serve(l, nil)
|
||||
if err != nil {
|
||||
log.Crit("error serving HTTP", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-c
|
||||
}
|
@ -1 +0,0 @@
|
||||
pserver
|
@ -1,15 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
go_binary(pserver SRCS pserver.go DEPS paddle_go_optimizer)
|
@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/namsral/flag"
|
||||
"github.com/topicai/candy"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/pserver"
|
||||
log "github.com/inconshreveable/log15"
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.Int("port", 8001, "port of the pserver")
|
||||
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
|
||||
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
|
||||
"comma separated endpoint string for pserver to connect to etcd")
|
||||
dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
|
||||
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
|
||||
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
|
||||
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
|
||||
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
|
||||
logLevel := flag.String("log-level", "info",
|
||||
"log level, possible values: debug, info, warn, error, crit")
|
||||
flag.Parse()
|
||||
|
||||
lvl, err := log.LvlFromString(*logLevel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
log.Root().SetHandler(
|
||||
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
|
||||
)
|
||||
|
||||
var idx int
|
||||
|
||||
var cp pserver.Checkpoint
|
||||
var e *pserver.EtcdClient
|
||||
if *index >= 0 {
|
||||
idx = *index
|
||||
} else {
|
||||
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
|
||||
idx, err = e.Register(*port)
|
||||
candy.Must(err)
|
||||
|
||||
cp, err = pserver.LoadCheckpoint(e, idx)
|
||||
if err != nil {
|
||||
if err == pserver.ErrCheckpointNotFound {
|
||||
log.Info("load checkpoint error", "error", err)
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shutdown := func() {
|
||||
log.Info("shutting down gracefully")
|
||||
sErr := e.Shutdown()
|
||||
if sErr != nil {
|
||||
log.Error("error shutting down", log.Ctx{"error": sErr})
|
||||
}
|
||||
}
|
||||
|
||||
// Guaranteed to run even panic happens.
|
||||
defer shutdown()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
|
||||
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
|
||||
candy.Must(err)
|
||||
|
||||
err = rpc.Register(s)
|
||||
candy.Must(err)
|
||||
|
||||
rpc.HandleHTTP()
|
||||
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
|
||||
candy.Must(err)
|
||||
|
||||
go func() {
|
||||
log.Info("serving pserver", log.Ctx{"port": *port})
|
||||
err = http.Serve(l, nil)
|
||||
candy.Must(err)
|
||||
}()
|
||||
|
||||
<-c
|
||||
}
|
@ -1,120 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package connection
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TODO(helin): add TCP re-connect logic
|
||||
|
||||
// Conn is a connection to a parameter server
|
||||
type Conn struct {
|
||||
mu sync.Mutex
|
||||
client *rpc.Client
|
||||
waitConn chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new connection.
|
||||
func New() *Conn {
|
||||
c := &Conn{}
|
||||
return c
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// Connect connects the connection to a address.
|
||||
func (c *Conn) Connect(addr string) error {
|
||||
c.mu.Lock()
|
||||
if c.client != nil {
|
||||
err := c.client.Close()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
c.client = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
client, err := rpc.DialHTTP("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.client == nil {
|
||||
c.client = client
|
||||
if c.waitConn != nil {
|
||||
close(c.waitConn)
|
||||
c.waitConn = nil
|
||||
}
|
||||
} else {
|
||||
err := client.Close()
|
||||
if err != nil {
|
||||
log.Errorln(err)
|
||||
}
|
||||
|
||||
return errors.New("client already set from a concurrent goroutine")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(helin): refactor Call to be able to perform given retry
|
||||
// policy.
|
||||
|
||||
// Call make a RPC call.
|
||||
//
|
||||
// Call will be blocked until the connection to remote RPC service
|
||||
// being established.
|
||||
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
|
||||
c.mu.Lock()
|
||||
client := c.client
|
||||
var waitCh chan struct{}
|
||||
if client == nil {
|
||||
if c.waitConn != nil {
|
||||
waitCh = c.waitConn
|
||||
} else {
|
||||
waitCh = make(chan struct{})
|
||||
c.waitConn = waitCh
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if waitCh != nil {
|
||||
// wait until new connection being established
|
||||
<-waitCh
|
||||
return c.Call(serviceMethod, args, reply)
|
||||
}
|
||||
|
||||
return client.Call(serviceMethod, args, reply)
|
||||
}
|
@ -1,233 +0,0 @@
|
||||
hash: 107c058cf5c9163a75d40eef2273a793c36112683c25d72aa8288827fdde3a19
|
||||
updated: 2017-10-30T03:46:19.137696069Z
|
||||
imports:
|
||||
- name: github.com/alecthomas/gometalinter
|
||||
version: bae2f1293d092fd8167939d5108d1b025eaef9de
|
||||
- name: github.com/beorn7/perks
|
||||
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
|
||||
subpackages:
|
||||
- quantile
|
||||
- name: github.com/boltdb/bolt
|
||||
version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
|
||||
- name: github.com/cockroachdb/cmux
|
||||
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
|
||||
- name: github.com/coreos/etcd
|
||||
version: f1d7dd87da3e8feab4aaf675b8e29c6a5ed5f58b
|
||||
subpackages:
|
||||
- alarm
|
||||
- auth
|
||||
- auth/authpb
|
||||
- client
|
||||
- clientv3
|
||||
- clientv3/concurrency
|
||||
- compactor
|
||||
- discovery
|
||||
- embed
|
||||
- error
|
||||
- etcdserver
|
||||
- etcdserver/api
|
||||
- etcdserver/api/etcdhttp
|
||||
- etcdserver/api/v2http
|
||||
- etcdserver/api/v2http/httptypes
|
||||
- etcdserver/api/v3client
|
||||
- etcdserver/api/v3election
|
||||
- etcdserver/api/v3election/v3electionpb
|
||||
- etcdserver/api/v3election/v3electionpb/gw
|
||||
- etcdserver/api/v3lock
|
||||
- etcdserver/api/v3lock/v3lockpb
|
||||
- etcdserver/api/v3lock/v3lockpb/gw
|
||||
- etcdserver/api/v3rpc
|
||||
- etcdserver/api/v3rpc/rpctypes
|
||||
- etcdserver/auth
|
||||
- etcdserver/etcdserverpb
|
||||
- etcdserver/etcdserverpb/gw
|
||||
- etcdserver/membership
|
||||
- etcdserver/stats
|
||||
- lease
|
||||
- lease/leasehttp
|
||||
- lease/leasepb
|
||||
- mvcc
|
||||
- mvcc/backend
|
||||
- mvcc/mvccpb
|
||||
- pkg/adt
|
||||
- pkg/contention
|
||||
- pkg/cors
|
||||
- pkg/cpuutil
|
||||
- pkg/crc
|
||||
- pkg/debugutil
|
||||
- pkg/fileutil
|
||||
- pkg/httputil
|
||||
- pkg/idutil
|
||||
- pkg/ioutil
|
||||
- pkg/logutil
|
||||
- pkg/monotime
|
||||
- pkg/netutil
|
||||
- pkg/pathutil
|
||||
- pkg/pbutil
|
||||
- pkg/runtime
|
||||
- pkg/schedule
|
||||
- pkg/srv
|
||||
- pkg/tlsutil
|
||||
- pkg/transport
|
||||
- pkg/types
|
||||
- pkg/wait
|
||||
- proxy/grpcproxy/adapter
|
||||
- raft
|
||||
- raft/raftpb
|
||||
- rafthttp
|
||||
- snap
|
||||
- snap/snappb
|
||||
- store
|
||||
- version
|
||||
- wal
|
||||
- wal/walpb
|
||||
- name: github.com/coreos/go-semver
|
||||
version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
|
||||
subpackages:
|
||||
- semver
|
||||
- name: github.com/coreos/go-systemd
|
||||
version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
|
||||
subpackages:
|
||||
- daemon
|
||||
- journal
|
||||
- util
|
||||
- name: github.com/coreos/pkg
|
||||
version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
|
||||
subpackages:
|
||||
- capnslog
|
||||
- name: github.com/dgrijalva/jwt-go
|
||||
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
|
||||
- name: github.com/ghodss/yaml
|
||||
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
|
||||
- name: github.com/go-stack/stack
|
||||
version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf
|
||||
- name: github.com/gogo/protobuf
|
||||
version: 909568be09de550ed094403c2bf8a261b5bb730a
|
||||
subpackages:
|
||||
- proto
|
||||
- name: github.com/golang/protobuf
|
||||
version: 4bd1920723d7b7c925de087aa32e2187708897f7
|
||||
subpackages:
|
||||
- jsonpb
|
||||
- proto
|
||||
- name: github.com/golang/snappy
|
||||
version: 553a641470496b2327abcac10b36396bd98e45c9
|
||||
- name: github.com/google/btree
|
||||
version: 925471ac9e2131377a91e1595defec898166fe49
|
||||
- name: github.com/grpc-ecosystem/go-grpc-prometheus
|
||||
version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
|
||||
- name: github.com/grpc-ecosystem/grpc-gateway
|
||||
version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
|
||||
subpackages:
|
||||
- runtime
|
||||
- runtime/internal
|
||||
- utilities
|
||||
- name: github.com/inconshreveable/log15
|
||||
version: 0decfc6c20d9ca0ad143b0e89dcaa20f810b4fb3
|
||||
- name: github.com/jonboulle/clockwork
|
||||
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
|
||||
- name: github.com/mattn/go-colorable
|
||||
version: 5411d3eea5978e6cdc258b30de592b60df6aba96
|
||||
- name: github.com/mattn/go-isatty
|
||||
version: 57fdcb988a5c543893cc61bce354a6e24ab70022
|
||||
- name: github.com/matttproud/golang_protobuf_extensions
|
||||
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
|
||||
subpackages:
|
||||
- pbutil
|
||||
- name: github.com/namsral/flag
|
||||
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
|
||||
- name: github.com/PaddlePaddle/recordio
|
||||
version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
|
||||
- name: github.com/prometheus/client_golang
|
||||
version: c5b7fccd204277076155f10851dad72b76a49317
|
||||
subpackages:
|
||||
- prometheus
|
||||
- name: github.com/prometheus/client_model
|
||||
version: 6f3806018612930941127f2a7c6c453ba2c527d2
|
||||
subpackages:
|
||||
- go
|
||||
- name: github.com/prometheus/common
|
||||
version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
|
||||
subpackages:
|
||||
- expfmt
|
||||
- internal/bitbucket.org/ww/goautoneg
|
||||
- model
|
||||
- name: github.com/prometheus/procfs
|
||||
version: a1dba9ce8baed984a2495b658c82687f8157b98f
|
||||
subpackages:
|
||||
- xfs
|
||||
- name: github.com/satori/go.uuid
|
||||
version: 879c5887cd475cd7864858769793b2ceb0d44feb
|
||||
- name: github.com/sirupsen/logrus
|
||||
version: f006c2ac4710855cf0f916dd6b77acf6b048dc6e
|
||||
- name: github.com/topicai/candy
|
||||
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
|
||||
- name: github.com/ugorji/go
|
||||
version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
|
||||
subpackages:
|
||||
- codec
|
||||
- name: github.com/xiang90/probing
|
||||
version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
|
||||
- name: golang.org/x/crypto
|
||||
version: 9419663f5a44be8b34ca85f08abc5fe1be11f8a3
|
||||
repo: https://github.com/golang/crypto.git
|
||||
vcs: git
|
||||
subpackages:
|
||||
- bcrypt
|
||||
- blowfish
|
||||
- ssh/terminal
|
||||
- name: golang.org/x/net
|
||||
version: c8c74377599bd978aee1cf3b9b63a8634051cec2
|
||||
subpackages:
|
||||
- context
|
||||
- http2
|
||||
- http2/hpack
|
||||
- idna
|
||||
- internal/timeseries
|
||||
- lex/httplex
|
||||
- trace
|
||||
- name: golang.org/x/sys
|
||||
version: e48874b42435b4347fc52bdee0424a52abc974d7
|
||||
repo: https://github.com/golang/sys.git
|
||||
vcs: git
|
||||
subpackages:
|
||||
- unix
|
||||
- windows
|
||||
- name: golang.org/x/text
|
||||
version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
|
||||
repo: https://github.com/golang/text.git
|
||||
vcs: git
|
||||
subpackages:
|
||||
- secure/bidirule
|
||||
- transform
|
||||
- unicode/bidi
|
||||
- unicode/norm
|
||||
- name: google.golang.org/grpc
|
||||
version: 8050b9cbc271307e5a716a9d782803d09b0d6f2d
|
||||
subpackages:
|
||||
- codes
|
||||
- credentials
|
||||
- grpclog
|
||||
- internal
|
||||
- keepalive
|
||||
- metadata
|
||||
- naming
|
||||
- peer
|
||||
- stats
|
||||
- tap
|
||||
- transport
|
||||
- name: gopkg.in/yaml.v2
|
||||
version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
|
||||
testImports:
|
||||
- name: github.com/davecgh/go-spew
|
||||
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
|
||||
subpackages:
|
||||
- spew
|
||||
- name: github.com/pmezard/go-difflib
|
||||
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
|
||||
subpackages:
|
||||
- difflib
|
||||
- name: github.com/stretchr/testify
|
||||
version: 05e8a0eda380579888eb53c394909df027f06991
|
||||
subpackages:
|
||||
- assert
|
@ -1,33 +0,0 @@
|
||||
package: github.com/PaddlePaddle/Paddle/go
|
||||
import:
|
||||
- package: github.com/PaddlePaddle/recordio
|
||||
- package: github.com/coreos/etcd
|
||||
version: ^3.2.1
|
||||
subpackages:
|
||||
- clientv3
|
||||
- clientv3/concurrency
|
||||
- embed
|
||||
- etcdserver
|
||||
- package: github.com/namsral/flag
|
||||
version: ^1.7.4-pre
|
||||
- package: github.com/sirupsen/logrus
|
||||
version: ^1.0.0
|
||||
- package: github.com/topicai/candy
|
||||
- package: golang.org/x/crypto
|
||||
repo: https://github.com/golang/crypto.git
|
||||
vcs: git
|
||||
- package: golang.org/x/sys
|
||||
repo: https://github.com/golang/sys.git
|
||||
vcs: git
|
||||
- package: golang.org/x/text
|
||||
repo: https://github.com/golang/text.git
|
||||
vcs: git
|
||||
- package: github.com/satori/go.uuid
|
||||
version: v1.1.0
|
||||
- package: github.com/alecthomas/gometalinter
|
||||
version: v1.2.1
|
||||
- package: github.com/inconshreveable/log15
|
||||
version: v2.13
|
||||
- package: github.com/go-stack/stack
|
||||
version: v1.6.0
|
||||
- package: github.com/golang/protobuf
|
@ -1,17 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
if(WITH_TESTING)
|
||||
go_test(master_test)
|
||||
endif()
|
@ -1,15 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
go_library(paddle_master SHARED DEPS paddle_go_optimizer)
|
@ -1,196 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#define PADDLE_MASTER_OK 0
|
||||
#define PADDLE_MASTER_ERROR -1
|
||||
|
||||
#define PADDLE_SAVE_MODEL_OK 1
|
||||
#define PADDLE_SAVE_MODEL_SKIP 0
|
||||
|
||||
typedef int paddle_master_client;
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
log "github.com/inconshreveable/log15"
|
||||
)
|
||||
|
||||
var mu sync.Mutex
|
||||
var handleMap = make(map[C.paddle_master_client]*master.Client)
|
||||
var curHandle C.paddle_master_client
|
||||
|
||||
func init() {
|
||||
log.Root().SetHandler(
|
||||
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
|
||||
)
|
||||
}
|
||||
|
||||
func add(c *master.Client) C.paddle_master_client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
client := curHandle
|
||||
curHandle++
|
||||
handleMap[client] = c
|
||||
return client
|
||||
}
|
||||
|
||||
func get(client C.paddle_master_client) *master.Client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return handleMap[client]
|
||||
}
|
||||
|
||||
func remove(client C.paddle_master_client) *master.Client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
h := handleMap[client]
|
||||
delete(handleMap, client)
|
||||
return h
|
||||
}
|
||||
|
||||
//export paddle_new_etcd_master_client
|
||||
//
|
||||
// bufSize is the record buffer size.
|
||||
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
|
||||
p := C.GoString(etcdEndpoints)
|
||||
endpoints := strings.Split(p, ",")
|
||||
c, err := master.NewClient(
|
||||
master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
|
||||
master.WithBuffer(bufSize),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return add(c)
|
||||
}
|
||||
|
||||
//export paddle_new_master_client
|
||||
//
|
||||
// bufSize is the record buffer size.
|
||||
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
|
||||
a := C.GoString(addr)
|
||||
c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return add(c)
|
||||
}
|
||||
|
||||
//export paddle_release_master_client
|
||||
func paddle_release_master_client(client C.paddle_master_client) {
|
||||
remove(client)
|
||||
}
|
||||
|
||||
//export paddle_start_get_records
|
||||
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
|
||||
c := get(client)
|
||||
c.StartGetRecords(int(pass))
|
||||
}
|
||||
|
||||
//export paddle_set_dataset
|
||||
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
|
||||
c := get(client)
|
||||
var paths []string
|
||||
for i := 0; i < int(size); i++ {
|
||||
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
|
||||
str := C.GoString(*ptr)
|
||||
paths = append(paths, str)
|
||||
}
|
||||
err := c.SetDataset(paths)
|
||||
if err != nil {
|
||||
log.Error("error set dataset",
|
||||
log.Ctx{"error": err, "paths": paths})
|
||||
return C.PADDLE_MASTER_ERROR
|
||||
}
|
||||
|
||||
return C.PADDLE_MASTER_OK
|
||||
}
|
||||
|
||||
// paddle_next_record gets the nexts training record.
|
||||
//
|
||||
// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
|
||||
//
|
||||
//export paddle_next_record
|
||||
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
|
||||
c := get(client)
|
||||
r, err := c.NextRecord()
|
||||
if err != nil {
|
||||
// NOTE: use errors to indicate pass ends
|
||||
if err.Error() == master.ErrAllTaskFailed.Error() ||
|
||||
err.Error() == master.ErrNoMoreAvailable.Error() ||
|
||||
err.Error() == master.ErrPassBefore.Error() {
|
||||
return -2
|
||||
}
|
||||
*record = (*C.uchar)(nil)
|
||||
return -1
|
||||
}
|
||||
|
||||
if len(r) == 0 {
|
||||
// Empty record
|
||||
*record = (*C.uchar)(nil)
|
||||
return 0
|
||||
}
|
||||
|
||||
size := C.size_t(len(r))
|
||||
*record = (*C.uchar)(C.malloc(size))
|
||||
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
|
||||
return C.int(size)
|
||||
}
|
||||
|
||||
// paddle_request_save_model requests the master server to approve the
|
||||
// caller to save the model.
|
||||
//
|
||||
// returns 1 if the save the model request is approved, 0 if the
|
||||
// request is rejected because other trainer is saving the model, -1
|
||||
// if error happened.
|
||||
//
|
||||
//export paddle_request_save_model
|
||||
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
|
||||
c := get(client)
|
||||
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
|
||||
if err != nil {
|
||||
log.Error("error request save model", log.Ctx{"error": err})
|
||||
return C.PADDLE_MASTER_ERROR
|
||||
}
|
||||
|
||||
if need {
|
||||
return C.PADDLE_SAVE_MODEL_OK
|
||||
}
|
||||
|
||||
return C.PADDLE_SAVE_MODEL_SKIP
|
||||
}
|
||||
|
||||
//export mem_free
|
||||
func mem_free(p unsafe.Pointer) {
|
||||
// "free" may be a better name for this function, but doing so
|
||||
// will cause calling any function of this library from Python
|
||||
// ctypes hanging.
|
||||
C.free(p)
|
||||
}
|
||||
|
||||
func main() {}
|
File diff suppressed because it is too large
Load Diff
@ -1,152 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/connection"
|
||||
"github.com/PaddlePaddle/recordio"
|
||||
)
|
||||
|
||||
const (
|
||||
totalTask = 20
|
||||
chunkPerTask = 10
|
||||
)
|
||||
|
||||
func TestGetFinishTask(t *testing.T) {
|
||||
const path = "/tmp/master_client_test_0"
|
||||
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func(l net.Listener) {
|
||||
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
|
||||
if sErr != nil {
|
||||
panic(sErr)
|
||||
}
|
||||
|
||||
server := rpc.NewServer()
|
||||
sErr = server.Register(s)
|
||||
if sErr != nil {
|
||||
panic(sErr)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
sErr = http.Serve(l, mux)
|
||||
if sErr != nil {
|
||||
panic(sErr)
|
||||
}
|
||||
}(l)
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i := 0; i < totalTask*chunkPerTask; i++ {
|
||||
w := recordio.NewWriter(f, -1, -1)
|
||||
_, err = w.Write(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// call Close to force RecordIO writing a chunk.
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Manually intialize client to avoid calling c.getRecords()
|
||||
c := &Client{}
|
||||
c.conn = connection.New()
|
||||
addr := fmt.Sprintf(":%d", p)
|
||||
ch := make(chan string, 1)
|
||||
ch <- addr
|
||||
go c.monitorMaster(ch)
|
||||
|
||||
err = c.SetDataset([]string{path})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
checkOnePass := func(i int) {
|
||||
var tasks []Task
|
||||
for idx := 0; idx < totalTask; idx++ {
|
||||
task, cErr := c.getTask(i)
|
||||
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
|
||||
t.Fatalf("error: %v, pass: %d\n", cErr, i)
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
// getting task before task finishes should return error
|
||||
_, cErr := c.getTask(i)
|
||||
if cErr == nil {
|
||||
t.Fatalf("Should get error, pass: %d\n", i)
|
||||
}
|
||||
|
||||
cErr = c.taskFinished(tasks[0].Meta.ID)
|
||||
if cErr != nil {
|
||||
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
|
||||
}
|
||||
// call taskFailed once won't put the task to failed queue, just ensure
|
||||
// the call
|
||||
cErr = c.taskFailed(tasks[0].Meta)
|
||||
if cErr != nil {
|
||||
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
|
||||
}
|
||||
|
||||
tasks = tasks[1:]
|
||||
_, cErr = c.getTask(i)
|
||||
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
|
||||
t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
|
||||
}
|
||||
|
||||
for _, task := range tasks {
|
||||
cErr = c.taskFinished(task.Meta.ID)
|
||||
if cErr != nil {
|
||||
t.Fatal(cErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
// init pass data
|
||||
c.StartGetRecords(i)
|
||||
checkOnePass(i)
|
||||
}
|
||||
}
|
@ -1,150 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
"github.com/PaddlePaddle/recordio"
|
||||
)
|
||||
|
||||
// tool function for testing output goroutine ids
|
||||
func goid() int {
|
||||
var buf [64]byte
|
||||
n := runtime.Stack(buf[:], false)
|
||||
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
|
||||
id, err := strconv.Atoi(idField)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func TestNextRecord(t *testing.T) {
|
||||
const (
|
||||
path = "/tmp/master_client_TestFull"
|
||||
total = 50
|
||||
)
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func(l net.Listener) {
|
||||
s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
server := rpc.NewServer()
|
||||
err = server.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
err = http.Serve(l, mux)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(l)
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
w := recordio.NewWriter(f, 1, -1)
|
||||
for i := 0; i < total; i++ {
|
||||
_, err = w.Write([]byte{byte(i)})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// start several client to test task fetching
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 4; i++ {
|
||||
wg.Add(1)
|
||||
// test for multiple concurrent clients
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// each go-routine needs a single client connection instance
|
||||
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
e = c.SetDataset([]string{path})
|
||||
if e != nil {
|
||||
panic(e)
|
||||
}
|
||||
|
||||
// test for n passes
|
||||
for pass := 0; pass < 10; pass++ {
|
||||
c.StartGetRecords(pass)
|
||||
|
||||
received := make(map[byte]bool)
|
||||
taskid := 0
|
||||
for {
|
||||
r, e := c.NextRecord()
|
||||
if e != nil {
|
||||
// ErrorPassAfter will wait, else break for next pass
|
||||
if e.Error() == master.ErrPassBefore.Error() ||
|
||||
e.Error() == master.ErrNoMoreAvailable.Error() {
|
||||
break
|
||||
}
|
||||
t.Fatal(pass, taskid, "Read error:", e)
|
||||
}
|
||||
if len(r) != 1 {
|
||||
t.Fatal(pass, taskid, "Length should be 1.", r)
|
||||
}
|
||||
if received[r[0]] {
|
||||
t.Fatal(pass, taskid, "Received duplicate.", received, r)
|
||||
}
|
||||
taskid++
|
||||
received[r[0]] = true
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
@ -1,201 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/coreos/etcd/clientv3/concurrency"
|
||||
log "github.com/inconshreveable/log15"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultLockPath is the default etcd master lock path.
|
||||
DefaultLockPath = "/master/lock"
|
||||
// DefaultStatePath is the default etcd key for master state.
|
||||
DefaultStatePath = "/master/state"
|
||||
// DefaultAddrPath is the default etcd key for master address.
|
||||
DefaultAddrPath = "/master/addr"
|
||||
)
|
||||
|
||||
// EtcdClient is the etcd client that the master uses for fault
|
||||
// tolerance and service registry.
|
||||
type EtcdClient struct {
|
||||
lockPath string
|
||||
statePath string
|
||||
client *clientv3.Client
|
||||
lock *concurrency.Mutex
|
||||
sess *concurrency.Session
|
||||
}
|
||||
|
||||
// NewEtcdClient creates a new EtcdClient.
|
||||
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
|
||||
log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: endpoints,
|
||||
DialTimeout: dialTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lock := concurrency.NewMutex(sess, lockPath)
|
||||
// It's fine for the lock to get stuck, in this case we have
|
||||
// multiple master servers running (only configured to have
|
||||
// one master running, but split-brain problem may cause
|
||||
// multiple master servers running), and the cluster management
|
||||
// software will kill one of them.
|
||||
log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
|
||||
err = lock.Lock(context.TODO())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
|
||||
|
||||
put := clientv3.OpPut(addrPath, addr)
|
||||
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Crit("No longer owns the master lock. Exiting.")
|
||||
panic("No longer owns the master lock. Exiting.")
|
||||
}
|
||||
|
||||
e := &EtcdClient{
|
||||
lockPath: lockPath,
|
||||
statePath: statePath,
|
||||
client: cli,
|
||||
lock: lock,
|
||||
sess: sess,
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// Save saves the state into the etcd.
|
||||
func (e *EtcdClient) Save(state []byte) error {
|
||||
ctx := context.TODO()
|
||||
put := clientv3.OpPut(e.statePath, string(state))
|
||||
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Error("No longer owns the lock, trying to lock again")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
err := e.lock.Lock(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
// We lost the master lock and can not acquire
|
||||
// it back, it means some other master is
|
||||
// already started. We don't want cluster
|
||||
// management system to kill the master server
|
||||
// who is holding the lock and running
|
||||
// correctly. So the most feasible solution is
|
||||
// to kill current master server. The current
|
||||
// state is not saved, but the trainer's RPC
|
||||
// call will fail, so the trainer will retry.
|
||||
log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
|
||||
panic("Could not acquire the lock at %s: %v. Exiting.")
|
||||
}
|
||||
log.Info("Successfully acquired lock at %s.", e.lockPath)
|
||||
return e.Save(state)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load loads the state from etcd.
|
||||
func (e *EtcdClient) Load() ([]byte, error) {
|
||||
ctx := context.TODO()
|
||||
get := clientv3.OpGet(e.statePath)
|
||||
|
||||
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Error("No longer owns the lock, trying to lock and load again.")
|
||||
err = e.lock.Lock(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return e.Load()
|
||||
}
|
||||
|
||||
kvs := resp.Responses[0].GetResponseRange().Kvs
|
||||
if len(kvs) == 0 {
|
||||
// No state exists
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
state := kvs[0].Value
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the etcd client gracefully.
|
||||
func (e *EtcdClient) Shutdown() error {
|
||||
err := e.sess.Close()
|
||||
newErr := e.client.Close()
|
||||
if newErr != nil {
|
||||
if err == nil {
|
||||
err = newErr
|
||||
} else {
|
||||
log.Error("shutdown error", log.Ctx{"error": newErr})
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetKey gets the value by the specify key.
|
||||
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
resp, err := c.Get(ctx, key)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
kvs := resp.Kvs
|
||||
if len(kvs) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
v := kvs[0].Value
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
// watchKey watches the specify key and send to valChan if there is some event.
|
||||
func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
|
||||
rch := c.Watch(context.Background(), key)
|
||||
for wresp := range rch {
|
||||
for _, ev := range wresp.Events {
|
||||
// if received event is DELETE, the value will be an empty string
|
||||
log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
|
||||
valChan <- string(ev.Kv.Value)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,47 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master
|
||||
|
||||
import "sync"
|
||||
|
||||
// InMemStore is an in memory implementation of Store interface.
|
||||
//
|
||||
// It does not tolerate the fault that causes the program to crash.
|
||||
type InMemStore struct {
|
||||
mu sync.Mutex
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// Save saves the state into the in-memory store.
|
||||
func (m *InMemStore) Save(state []byte) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.buf = state
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load loads the state from the in-memory store.
|
||||
func (m *InMemStore) Load() ([]byte, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.buf, nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the in mem store.
|
||||
func (m *InMemStore) Shutdown() error {
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,52 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPartitionCount(t *testing.T) {
|
||||
cs := make([]Chunk, 100)
|
||||
ts := partition(cs, 5)
|
||||
if len(ts) != 20 {
|
||||
t.Error(len(ts))
|
||||
}
|
||||
|
||||
cs = make([]Chunk, 101)
|
||||
ts = partition(cs, 5)
|
||||
if len(ts) != 21 {
|
||||
t.Error(len(ts))
|
||||
}
|
||||
|
||||
ts = partition(cs, 1)
|
||||
if len(ts) != 101 {
|
||||
t.Error(len(ts))
|
||||
}
|
||||
|
||||
ts = partition(cs, 0)
|
||||
if len(ts) != 101 {
|
||||
t.Error(len(ts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartionIndex(t *testing.T) {
|
||||
cs := make([]Chunk, 100)
|
||||
ts := partition(cs, 20)
|
||||
for i := range ts {
|
||||
// test auto increament ids
|
||||
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
|
||||
t.Error(ts[i], i)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
package master_test
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/coreos/etcd/embed"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewServiceWithEtcd(t *testing.T) {
|
||||
// setup an embed etcd server
|
||||
etcdDir, err := ioutil.TempDir("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cfg := embed.NewConfig()
|
||||
lpurl, _ := url.Parse("http://localhost:0")
|
||||
lcurl, _ := url.Parse("http://localhost:0")
|
||||
cfg.LPUrls = []url.URL{*lpurl}
|
||||
cfg.LCUrls = []url.URL{*lcurl}
|
||||
cfg.Dir = etcdDir
|
||||
e, err := embed.StartEtcd(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
e.Close()
|
||||
if err := os.RemoveAll(etcdDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-e.Server.ReadyNotify()
|
||||
|
||||
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
|
||||
endpoint := "127.0.0.1:" + port
|
||||
|
||||
ep := []string{endpoint}
|
||||
masterAddr := "127.0.0.1:3306"
|
||||
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = master.NewService(store, 10, 10, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: ep,
|
||||
DialTimeout: 3 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := cli.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// test master process registry itself into etcd server.
|
||||
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
# Ignore everything in this directory
|
||||
*
|
||||
# Except this file
|
||||
!.gitignore
|
@ -1,17 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
if(WITH_TESTING)
|
||||
go_test(pserver_test DEPS paddle_go_optimizer gen_proto_go)
|
||||
endif()
|
@ -1,17 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
if(WITH_TESTING)
|
||||
go_test(pserver_client_test DEPS paddle_go_optimizer)
|
||||
endif()
|
@ -1 +0,0 @@
|
||||
libpaddle_go_optimizer.a
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue