merge baidu/develop

cblas_new
qijun 8 years ago
commit 4cc42171db

@ -22,9 +22,11 @@
hooks: hooks:
- id: clang-formater - id: clang-formater
- repo: https://github.com/PaddlePaddle/pre-commit-golang - repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 16398aeccf263adaf53b2495eed0406347d76281 sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks: hooks:
- id: go-fmt - id: go-fmt
types: [go] types:
- id: gometalinter - go
types: [go] - id: gometalinter
types:
- go

@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add( ExternalProject_Add(
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048" GIT_TAG "master"
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""

@ -153,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here. # So, don't set these flags here.
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11) LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread)
LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")

@ -19,6 +19,8 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os"
"os/signal"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -68,6 +70,20 @@ func main() {
store = &master.InMemStore{} store = &master.InMemStore{}
} }
shutdown := func() {
log.Infoln("shutting down gracefully")
err := store.Shutdown()
if err != nil {
log.Errorln(err)
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -84,8 +100,12 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
err = http.Serve(l, nil) go func() {
if err != nil { err = http.Serve(l, nil)
log.Fatal(err) if err != nil {
} log.Fatal(err)
}
}()
<-c
} }

@ -18,6 +18,8 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os"
"os/signal"
"strconv" "strconv"
"time" "time"
@ -33,7 +35,8 @@ func main() {
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls") dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
@ -53,7 +56,7 @@ func main() {
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
idx, err = e.Register(*port) idx, err = e.Register(*port)
candy.Must(err) candy.Must(err)
@ -67,6 +70,20 @@ func main() {
} }
} }
shutdown := func() {
log.Infoln("shutting down gracefully")
sErr := e.Shutdown()
if sErr != nil {
log.Errorln(sErr)
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
candy.Must(err) candy.Must(err)
@ -77,7 +94,11 @@ func main() {
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
candy.Must(err) candy.Must(err)
log.Infof("start pserver at port %d", *port) go func() {
err = http.Serve(l, nil) log.Infof("start pserver at port %d", *port)
candy.Must(err) err = http.Serve(l, nil)
candy.Must(err)
}()
<-c
} }

176
go/glide.lock generated

@ -1,15 +1,105 @@
hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855 hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c
updated: 2017-07-11T10:04:40.786745417+08:00 updated: 2017-07-29T07:34:48.722757905+08:00
imports: imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
subpackages:
- quantile
- name: github.com/boltdb/bolt
version: 583e8937c61f1af6513608ccc75c97b6abdf4ff9
- name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: cb2a496c4ddd1c87a9f280e116649b599999ec79 version: c31bec0f29facff13f7c3e3d948e55dd6689ed42
subpackages: subpackages:
- alarm
- auth
- auth/authpb - auth/authpb
- client
- clientv3 - clientv3
- clientv3/concurrency - clientv3/concurrency
- compactor
- discovery
- embed
- error
- etcdserver
- etcdserver/api
- etcdserver/api/v2http
- etcdserver/api/v2http/httptypes
- etcdserver/api/v3client
- etcdserver/api/v3election
- etcdserver/api/v3election/v3electionpb
- etcdserver/api/v3election/v3electionpb/gw
- etcdserver/api/v3lock
- etcdserver/api/v3lock/v3lockpb
- etcdserver/api/v3lock/v3lockpb/gw
- etcdserver/api/v3rpc
- etcdserver/api/v3rpc/rpctypes - etcdserver/api/v3rpc/rpctypes
- etcdserver/auth
- etcdserver/etcdserverpb - etcdserver/etcdserverpb
- etcdserver/etcdserverpb/gw
- etcdserver/membership
- etcdserver/stats
- lease
- lease/leasehttp
- lease/leasepb
- mvcc
- mvcc/backend
- mvcc/mvccpb - mvcc/mvccpb
- pkg/adt
- pkg/contention
- pkg/cors
- pkg/cpuutil
- pkg/crc
- pkg/debugutil
- pkg/fileutil
- pkg/httputil
- pkg/idutil
- pkg/ioutil
- pkg/logutil
- pkg/monotime
- pkg/netutil
- pkg/pathutil
- pkg/pbutil
- pkg/runtime
- pkg/schedule
- pkg/srv
- pkg/tlsutil
- pkg/transport
- pkg/types
- pkg/wait
- proxy/grpcproxy/adapter
- raft
- raft/raftpb
- rafthttp
- snap
- snap/snappb
- store
- version
- wal
- wal/walpb
- name: github.com/coreos/go-semver
version: 8ab6407b697782a06568d4b7f1db25550ec2e4c6
subpackages:
- semver
- name: github.com/coreos/go-systemd
version: 48702e0da86bd25e76cfef347e2adeb434a0d0a6
subpackages:
- daemon
- journal
- util
- name: github.com/coreos/pkg
version: 3ac0863d7acf3bc44daf49afef8919af12f704ef
subpackages:
- capnslog
- name: github.com/dgrijalva/jwt-go
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages:
- proto
- name: github.com/golang/protobuf - name: github.com/golang/protobuf
version: 4bd1920723d7b7c925de087aa32e2187708897f7 version: 4bd1920723d7b7c925de087aa32e2187708897f7
subpackages: subpackages:
@ -17,14 +107,61 @@ imports:
- proto - proto
- name: github.com/golang/snappy - name: github.com/golang/snappy
version: 553a641470496b2327abcac10b36396bd98e45c9 version: 553a641470496b2327abcac10b36396bd98e45c9
- name: github.com/google/btree
version: 925471ac9e2131377a91e1595defec898166fe49
- name: github.com/grpc-ecosystem/go-grpc-prometheus
version: 6b7015e65d366bf3f19b2b2a000a831940f0f7e0
- name: github.com/grpc-ecosystem/grpc-gateway
version: 18d159699f2e83fc5bb9ef2f79465ca3f3122676
subpackages:
- runtime
- runtime/internal
- utilities
- name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages:
- pbutil
- name: github.com/namsral/flag - name: github.com/namsral/flag
version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04 version: 71ceffbeb0ba60fccc853971bb3ed4d7d90bfd04
- name: github.com/PaddlePaddle/recordio - name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129 version: 0432dee9fd4b24fb6840fb20a8c055b0c933fb81
- name: github.com/prometheus/client_golang
version: c5b7fccd204277076155f10851dad72b76a49317
subpackages:
- prometheus
- name: github.com/prometheus/client_model
version: 6f3806018612930941127f2a7c6c453ba2c527d2
subpackages:
- go
- name: github.com/prometheus/common
version: 49fee292b27bfff7f354ee0f64e1bc4850462edf
subpackages:
- expfmt
- internal/bitbucket.org/ww/goautoneg
- model
- name: github.com/prometheus/procfs
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1 version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy - name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: github.com/ugorji/go
version: ded73eae5db7e7a0ef6f55aace87a2873c5d2b74
subpackages:
- codec
- name: github.com/xiang90/probing
version: 07dd2e8dfe18522e9c447ba95f2fe95262f63bb2
- name: golang.org/x/crypto
version: 1351f936d976c60a0a48d728281922cf63eafb8d
repo: https://github.com/golang/crypto.git
vcs: git
subpackages:
- bcrypt
- blowfish
- name: golang.org/x/net - name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2 version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages: subpackages:
@ -36,11 +173,15 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: abf9c25f54453410d0c6668e519582a9e1115027 version: 0f826bdd13b500be0f1d4004938ad978fcc6031e
repo: https://github.com/golang/sys.git
vcs: git
subpackages: subpackages:
- unix - unix
- name: golang.org/x/text - name: golang.org/x/text
version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git
vcs: git
subpackages: subpackages:
- secure/bidirule - secure/bidirule
- transform - transform
@ -60,4 +201,23 @@ imports:
- stats - stats
- tap - tap
- transport - transport
testImports: [] - name: gopkg.in/yaml.v2
version: cd8b52f8269e0feb286dfeef29f8fe4d5b397e0b
testImports:
- name: github.com/davecgh/go-spew
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages:
- spew
- name: github.com/docker/docker
version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
subpackages:
- pkg/ioutils
- pkg/longpath
- name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages:
- difflib
- name: github.com/stretchr/testify
version: 05e8a0eda380579888eb53c394909df027f06991
subpackages:
- assert

@ -6,8 +6,19 @@ import:
subpackages: subpackages:
- clientv3 - clientv3
- clientv3/concurrency - clientv3/concurrency
- embed
- etcdserver
- package: github.com/namsral/flag - package: github.com/namsral/flag
version: ^1.7.4-pre version: ^1.7.4-pre
- package: github.com/sirupsen/logrus - package: github.com/sirupsen/logrus
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy - package: github.com/topicai/candy
- package: golang.org/x/crypto
vcs: git
repo: https://github.com/golang/crypto.git
- package: golang.org/x/sys
vcs: git
repo: https://github.com/golang/sys.git
- package: golang.org/x/text
vcs: git
repo: https://github.com/golang/text.git

@ -18,7 +18,6 @@ package main
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <stdio.h> #include <stdio.h>
#define PADDLE_MASTER_OK 0 #define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1 #define PADDLE_MASTER_ERROR -1
@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove(client) remove(client)
} }
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
c := get(client)
c.StartGetRecords(int(pass))
}
//export paddle_set_dataset //export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int { func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client) c := get(client)
@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
// paddle_next_record gets the nexts training record. // paddle_next_record gets the nexts training record.
// //
// returns number of bytes of the records if success, -1 if failed. // returns number of bytes of the records if success, -1 if failed, -2 if pass end.
// //
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
r, err := c.NextRecord() r, err := c.NextRecord()
if err != nil { if err != nil {
// Error // NOTE: use errors to indicate pass ends
// TODO: return the type of error? if err.Error() == master.ErrAllTaskFailed.Error() ||
err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil) *record = (*C.uchar)(nil)
return -1 return -1
} }

@ -16,7 +16,6 @@ package master
import ( import (
"os" "os"
"sync"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
@ -27,9 +26,9 @@ import (
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan record ch chan record
initChOnce sync.Once bufSize int
} }
type record struct { type record struct {
@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error {
if bufSize <= 0 { if bufSize <= 0 {
return nil return nil
} }
c.bufSize = bufSize
c.initChOnce.Do(func() {
c.ch = make(chan record, bufSize)
go c.getRecords()
})
return nil return nil
} }
} }
@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
c.ch = make(chan record, c.bufSize)
// FIXME: connection is created asyncrosly in monitorMaster go routine,
// ensure the connection is ready for use before calling c.addClient.
time.Sleep(time.Second)
return c, nil return c, nil
} }
func (c *Client) getRecords() { // StartGetRecords must be called at beginning of each pass
func (c *Client) StartGetRecords(passID int) {
go c.getRecords(passID)
}
func (c *Client) getRecords(passID int) {
for { for {
t, err := c.getTask() t, err := c.getTask(passID)
if err != nil { if err != nil {
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err) if err.Error() == ErrPassBefore.Error() ||
time.Sleep(3 * time.Second) err.Error() == ErrNoMoreAvailable.Error() ||
continue err.Error() == ErrAllTaskFailed.Error() {
c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait util last pass finishes
time.Sleep(time.Second * 3)
continue
}
log.Errorf("getTask error: %s", err)
} }
for _, chunk := range t.Chunks { for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path) f, e := os.Open(chunk.Path)
if err != nil { if e != nil {
log.Errorln(err) log.Errorln(e)
continue continue
} }
@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
} }
} }
// SetDataset set dataset for the master server to dispatch. // SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times at one pass. But only the first call
// will be honored.
// //
// SetDataset can be call multiple times from different nodes. But // After all tasks are done, another call of SetDataset will start another pass.
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error { func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil) err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
} }
// getTask gets a new task from the master server. // getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) { func (c *Client) getTask(passID int) (Task, error) {
var t Task var t Task
err := c.conn.Call("Service.GetTask", 0, &t) err := c.conn.Call("Service.GetTask", passID, &t)
return t, err return t, err
} }
@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
// thread-safe. // thread-safe.
func (c *Client) NextRecord() ([]byte, error) { func (c *Client) NextRecord() ([]byte, error) {
c.initChOnce.Do(func() {
// initialize with in case WithBuffer is not used.
c.ch = make(chan record, 0)
go c.getRecords()
})
r := <-c.ch r := <-c.ch
return r.r, r.err return r.r, r.err
} }

@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
server := rpc.NewServer() server := rpc.NewServer()
err = server.Register(s) sErr = server.Register(s)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
mux := http.NewServeMux() mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server) mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux) sErr = http.Serve(l, mux)
if err != nil { if sErr != nil {
panic(err) panic(sErr)
} }
}(l) }(l)
@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1) ch := make(chan string, 1)
ch <- addr ch <- addr
go c.monitorMaster(ch) go c.monitorMaster(ch)
err = c.SetDataset([]string{path}) err = c.SetDataset([]string{path})
if err != nil { if err != nil {
panic(err) panic(err)
@ -111,44 +112,47 @@ func TestGetFinishTask(t *testing.T) {
checkOnePass := func(i int) { checkOnePass := func(i int) {
var tasks []Task var tasks []Task
for idx := 0; idx < totalTask; idx++ { for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask() task, cErr := c.getTask(i)
if err != nil { if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("error: %v, pass: %d\n", cErr, i)
} }
tasks = append(tasks, task) tasks = append(tasks, task)
} }
_, err = c.getTask() // getting task before task finishes should return error
if err == nil { _, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i) t.Fatalf("Should get error, pass: %d\n", i)
} }
err = c.taskFinished(tasks[0].Meta.ID) cErr = c.taskFinished(tasks[0].Meta.ID)
if err != nil { if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", cErr, i)
} }
// call taskFailed once won't put the task to failed queue, just ensure
err = c.taskFailed(tasks[0].Meta) // the call
if err != nil { cErr = c.taskFailed(tasks[0].Meta)
t.Fatalf("Error: %v, pass: %d\n", err, i) if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
} }
tasks = tasks[1:] tasks = tasks[1:]
task, err := c.getTask() _, cErr = c.getTask(i)
if err != nil { if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatal(err) t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
} }
tasks = append(tasks, task)
for _, task := range tasks { for _, task := range tasks {
err = c.taskFinished(task.Meta.ID) cErr = c.taskFinished(task.Meta.ID)
if err != nil { if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatal(cErr)
} }
} }
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i) checkOnePass(i)
} }
} }

@ -20,8 +20,10 @@ import (
"net/http" "net/http"
"net/rpc" "net/rpc"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@ -29,6 +31,18 @@ import (
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
// tool function for testing output goroutine ids
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func TestNextRecord(t *testing.T) { func TestNextRecord(t *testing.T) {
const ( const (
path = "/tmp/master_client_TestFull" path = "/tmp/master_client_TestFull"
@ -45,7 +59,7 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -69,7 +83,7 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
w := recordio.NewWriter(f, -1, -1) w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
_, err = w.Write([]byte{byte(i)}) _, err = w.Write([]byte{byte(i)})
if err != nil { if err != nil {
@ -87,32 +101,49 @@ func TestNextRecord(t *testing.T) {
panic(err) panic(err)
} }
c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10)) // start several client to test task fetching
if err != nil { var wg sync.WaitGroup
panic(err) for i := 0; i < 4; i++ {
} wg.Add(1)
// test for multiple concurrent clients
err = c.SetDataset([]string{path}) go func() {
if err != nil { defer wg.Done()
panic(err) // each go-routine needs a single client connection instance
} c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
for pass := 0; pass < 50; pass++ { t.Fatal(e)
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
} }
e = c.SetDataset([]string{path})
if len(r) != 1 { if e != nil {
t.Fatal(pass, i, "Length should be 1.", r) panic(e)
} }
// test for n passes
if received[r[0]] { for pass := 0; pass < 10; pass++ {
t.Fatal(pass, i, "Received duplicate.", received, r) c.StartGetRecords(pass)
received := make(map[byte]bool)
taskid := 0
for {
r, e := c.NextRecord()
if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
}
if len(r) != 1 {
t.Fatal(pass, taskid, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal(pass, taskid, "Received duplicate.", received, r)
}
taskid++
received[r[0]] = true
}
} }
received[r[0]] = true }()
}
} }
wg.Wait()
} }

@ -39,15 +39,12 @@ type EtcdClient struct {
statePath string statePath string
client *clientv3.Client client *clientv3.Client
lock *concurrency.Mutex lock *concurrency.Mutex
sess *concurrency.Session
} }
// NewEtcdClient creates a new EtcdClient. // NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints) log.Debugf("Connecting to etcd at %v", endpoints)
// TODO(helin): gracefully shutdown etcd store. Because etcd
// store holds a etcd lock, even though the lock will expire
// when the lease timeout, we need to implement graceful
// shutdown to release the lock.
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,
DialTimeout: dialTimeout, DialTimeout: dialTimeout,
@ -67,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath) log.Infof("Trying to acquire lock at %s.", lockPath)
err = lock.Lock(context.TODO()) err = lock.Lock(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Debugf("Successfully acquired lock at %s.", lockPath) log.Infof("Successfully acquired lock at %s.", lockPath)
put := clientv3.OpPut(addrPath, addr) put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
@ -89,6 +86,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
statePath: statePath, statePath: statePath,
client: cli, client: cli,
lock: lock, lock: lock,
sess: sess,
} }
return e, nil return e, nil
@ -157,6 +155,21 @@ func (e *EtcdClient) Load() ([]byte, error) {
return state, nil return state, nil
} }
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
err := e.sess.Close()
newErr := e.client.Close()
if newErr != nil {
if err == nil {
err = newErr
} else {
log.Errorln(newErr)
}
}
return err
}
// GetKey gets the value by the specify key. // GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) { func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)

@ -40,3 +40,8 @@ func (m *InMemStore) Load() ([]byte, error) {
return m.buf, nil return m.buf, nil
} }
// Shutdown shuts down the in mem store.
func (m *InMemStore) Shutdown() error {
return nil
}

@ -19,6 +19,7 @@ import (
"compress/gzip" "compress/gzip"
"encoding/gob" "encoding/gob"
"errors" "errors"
"math/rand"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
@ -33,10 +34,23 @@ const (
dialTimeout = 5 * time.Second dialTimeout = 5 * time.Second
) )
// ErrAllTaskFailed occur when tasks are in done or failed state.
var ErrAllTaskFailed = errors.New("all task finished")
// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
var ErrNoMoreAvailable = errors.New("no more available task")
// ErrPassBefore client side pass number does not match with master counter.
var ErrPassBefore = errors.New("pass number smaller than master")
// ErrPassAfter client side pass number does not match with master counter.
var ErrPassAfter = errors.New("pass number larger than master")
// Store is the interface for save and load the master state. // Store is the interface for save and load the master state.
type Store interface { type Store interface {
Save([]byte) error Save([]byte) error
Load() ([]byte, error) Load() ([]byte, error)
Shutdown() error
} }
// Chunk is a chunk of data consisted of several data instances. // Chunk is a chunk of data consisted of several data instances.
@ -75,17 +89,26 @@ type Service struct {
chunksPerTask int chunksPerTask int
timeoutDur time.Duration timeoutDur time.Duration
failureMax int failureMax int
ready chan struct{}
store Store store Store
mu sync.Mutex ready chan struct{}
initDone bool initDone bool
taskQueues taskQueues
mu sync.Mutex
taskQueues taskQueues
currPass int
jobTasks []taskEntry
savingTrainer string savingTrainer string
} }
func partition(chunks []Chunk, chunksPerTask int) []taskEntry { func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0 // generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart := rand.Int()
counter := 0
timestamp := time.Now().Nanosecond()
id := timestamp + randStart + counter
if chunksPerTask <= 0 { if chunksPerTask <= 0 {
chunksPerTask = 1 chunksPerTask = 1
} }
@ -95,7 +118,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
for i, c := range chunks { for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.Meta.ID = id cur.Task.Meta.ID = id
id++ counter++
id = timestamp + randStart + counter
result = append(result, cur) result = append(result, cur)
cur.Task.Chunks = nil cur.Task.Chunks = nil
} }
@ -266,19 +290,21 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
return err return err
} }
s.taskQueues.Todo = partition(chunks, s.chunksPerTask) s.jobTasks = partition(chunks, s.chunksPerTask)
s.taskQueues.Todo = s.jobTasks
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Errorln(err)
return err return err
} }
close(s.ready) close(s.ready)
s.initDone = true s.initDone = true
return nil return nil
} }
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func (s *Service) processFailedTask(t taskEntry, epoch int) { func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch { if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the // new epoch, task launched after the
@ -302,8 +328,9 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
return return
} }
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t) s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
} }
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
@ -331,37 +358,30 @@ func (s *Service) logFields() log.Fields {
} }
// GetTask gets a new task from the service. // GetTask gets a new task from the service.
func (s *Service) GetTask(_ int, task *Task) error { // passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
select { select {
case <-s.ready: case <-s.ready:
} }
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if passID < s.currPass {
return ErrPassBefore
}
if passID > s.currPass {
// Client may get run to pass after master when one client faster than the
// other
return ErrPassAfter
}
if len(s.taskQueues.Todo) == 0 { if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 { if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 { log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
err := errors.New("all task failed") return ErrAllTaskFailed
log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF, because the error
// instance deserialized from RPC is a
// different instance than the error defined
// in package. So we need to figure out a way
// for client to check this error correctly.
err := errors.New("no more available task")
log.WithFields(s.logFields()).Warningln("No more available task.")
return err
} }
s.taskQueues.Todo = s.taskQueues.Done log.WithFields(s.logFields()).Warningln("No more available task.")
s.taskQueues.Done = nil return ErrNoMoreAvailable
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
} }
t := s.taskQueues.Todo[0] t := s.taskQueues.Todo[0]
@ -381,7 +401,7 @@ func (s *Service) GetTask(_ int, task *Task) error {
} }
// TaskFinished tell the service that a task is finished. // TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, _ *int) error { func (s *Service) TaskFinished(taskID int, dummy *int) error {
select { select {
case <-s.ready: case <-s.ready:
} }
@ -401,11 +421,14 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
delete(s.taskQueues.Pending, taskID) delete(s.taskQueues.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { // increase master side pass count if all tasks finished
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.") s.currPass++
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.taskQueues.Todo = s.jobTasks
s.taskQueues.Done = nil s.taskQueues.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.taskQueues.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass)
} }
err := s.snapshot() err := s.snapshot()
@ -416,7 +439,7 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
} }
// TaskFailed tells the service that a task is failed. // TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, _ *int) error { func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select { select {
case <-s.ready: case <-s.ready:
} }

@ -44,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100) cs := make([]Chunk, 100)
ts := partition(cs, 20) ts := partition(cs, 20)
for i := range ts { for i := range ts {
if ts[i].Task.Meta.ID != i { // test auto increament ids
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i) t.Error(ts[i], i)
} }
} }

@ -0,0 +1,68 @@
package master_test
import (
"os"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/docker/docker/pkg/ioutils"
"github.com/stretchr/testify/assert"
)
func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server
etcdDir, err := ioutils.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
select {
case <-e.Server.ReadyNotify():
t.Log("Server is ready!")
case <-time.After(60 * time.Second):
e.Server.Stop() // trigger a shutdown
t.Fatal("Server took too long to start!")
}
ep := []string{"127.0.0.1:2379"}
masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil {
t.Fatal(err)
}
_, err = master.NewService(store, 10, 10, 3)
if err != nil {
t.Fatal(err)
}
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: 3 * time.Second,
})
if err != nil {
t.Fatal(err)
}
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
if err != nil {
t.Fatal(err)
}
if err := cli.Close(); err != nil {
t.Fatal(err)
}
// test master process registry itself into etcd server.
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}

@ -55,10 +55,10 @@ var curHandle C.paddle_pserver_client
func add(c *client.Client) C.paddle_pserver_client { func add(c *client.Client) C.paddle_pserver_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
client := curHandle cli := curHandle
curHandle++ curHandle++
handleMap[client] = c handleMap[cli] = c
return client return cli
} }
func get(client C.paddle_pserver_client) *client.Client { func get(client C.paddle_pserver_client) *client.Client {

@ -6,16 +6,19 @@ import cPickle as pickle
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1") etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379" etcd_endpoint = "http://" + etcd_ip + ":2379"
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
def cloud_reader(): def cloud_reader():
print "connecting to master, etcd endpoints: ", etcd_endpoint global master_client
master_client = master.client(etcd_endpoint, 5, 64)
master_client.set_dataset( master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"]) ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
while 1: while 1:
r, e = master_client.next_record() r, e = master_client.next_record()
if not r: if not r:
if e != -2: # other errors
print "get record error:", e
break break
yield pickle.loads(r) yield pickle.loads(r)
@ -27,10 +30,12 @@ def main():
# network config # network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x, y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(name='w'), param_attr=paddle.attr.Param(
name='w', learning_rate=1e-3),
size=1, size=1,
act=paddle.activation.Linear(), act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(name='b')) bias_attr=paddle.attr.Param(
name='b', learning_rate=1e-3))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y) cost = paddle.layer.mse_cost(input=y_predict, label=y)
@ -38,9 +43,8 @@ def main():
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
# create optimizer of new remote updater to pserver # create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0) optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
print "etcd endoint: ", etcd_endpoint
trainer = paddle.trainer.SGD(cost=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=optimizer, update_equation=optimizer,
@ -51,6 +55,8 @@ def main():
# event_handler to print training and testing info # event_handler to print training and testing info
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0: if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % ( print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost) event.pass_id, event.batch_id, event.cost)

@ -34,16 +34,19 @@ const (
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information // PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/" PsCheckpoint = "/checkpoints/"
retryTimeout = 5 * time.Second
) )
// EtcdClient is the etcd client that the pserver uses for fault // EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination. // tolerance, service registry and coordination.
type EtcdClient struct { type EtcdClient struct {
numPservers int numPservers int
etcdEndpoints string endpoints string
etcdClient *clientv3.Client client *clientv3.Client
// etcdTimeout is also used as retry intervals. sess *concurrency.Session
etcdTimeout time.Duration dialTimeout time.Duration
ttlSec int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string externalIP string
// desired number of pservers in the job. // desired number of pservers in the job.
@ -52,11 +55,12 @@ type EtcdClient struct {
} }
// NewEtcdClient creates an EtcdClient // NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient { func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
return &EtcdClient{ return &EtcdClient{
etcdTimeout: timeout, dialTimeout: dialtimeout,
numPservers: numPservers, ttlSec: ttlSec,
etcdEndpoints: endpoints, numPservers: numPservers,
endpoints: endpoints,
} }
} }
@ -64,7 +68,6 @@ func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *Et
// //
// Register returns the index of the current pserver. // Register returns the index of the current pserver.
func (e *EtcdClient) Register(port int) (int, error) { func (e *EtcdClient) Register(port int) (int, error) {
var err error var err error
e.externalIP, err = networkhelper.GetExternalIP() e.externalIP, err = networkhelper.GetExternalIP()
if err != nil { if err != nil {
@ -72,19 +75,26 @@ func (e *EtcdClient) Register(port int) (int, error) {
} }
// initialize connection to etcd. // initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",") ep := strings.Split(e.endpoints, ",")
for { for {
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: ep, Endpoints: ep,
DialTimeout: e.etcdTimeout, DialTimeout: e.dialTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("connect to etcd error: %v", err) log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue
}
e.client = cli
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil {
log.Errorf("create etcd session error: %v", err)
time.Sleep(retryTimeout)
continue continue
} }
e.etcdClient = cli e.sess = sess
log.Debugf("inited client to %s", e.etcdEndpoints) log.Debugf("inited client to %s", e.endpoints)
break break
} }
// init /ps_desired using transaction, for multiple pservers may want to write // init /ps_desired using transaction, for multiple pservers may want to write
@ -95,7 +105,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
break break
@ -106,18 +116,18 @@ func (e *EtcdClient) Register(port int) (int, error) {
// wait and set s.desired init value // wait and set s.desired init value
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired) resp, err := e.client.Get(ctx, PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err) log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
if len(resp.Kvs) != 0 { if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err) log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change // NOTE: wait util ps_desired value change
continue continue
} }
@ -134,7 +144,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
time.Sleep(e.etcdTimeout) time.Sleep(retryTimeout)
continue continue
} }
break break
@ -144,10 +154,10 @@ func (e *EtcdClient) Register(port int) (int, error) {
} }
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired) dsStr := c.Get(PsDesired)
if dsStr == "" { if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers)) c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
} }
return nil return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
@ -156,7 +166,7 @@ func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (
// registerPserverEtcd registers pserver node on etcd using transaction. // registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) { func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
var idx int var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { _, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
@ -165,26 +175,10 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" { if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info // find the first id and write info
pserverAddr := e.externalIP + ":" + strconv.Itoa(port) pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID)) c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
log.Debugf("set pserver node %s with value %s", psKey, pserverAddr) log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished") log.Debug("register finished")
idx = i idx = i
registered = true registered = true
@ -207,7 +201,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
// GetKey gets the value by the specified key // GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.etcdClient.Get(ctx, key) resp, err := e.client.Get(ctx, key)
cancel() cancel()
if err != nil { if err != nil {
return []byte{}, err return []byte{}, err
@ -223,7 +217,27 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.etcdClient.Put(ctx, key, string(value)) _, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
cancel() cancel()
return err return err
} }
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
var err error
if e.sess != nil {
err = e.sess.Close()
}
if e.client != nil {
newErr := e.client.Close()
if newErr != nil {
if err != nil {
log.Errorln(newErr)
} else {
err = newErr
}
}
}
return err
}

@ -37,7 +37,7 @@ std::vector<std::string> Evaluator::getNames() const {
double Evaluator::getValue(const std::string name) const { double Evaluator::getValue(const std::string name) const {
paddle::Error err; paddle::Error err;
double v = m->rawPtr->getValue(name, &err); double v = m->rawPtr->getValue(name, &err);
if (err) { if (!err.isOK()) {
throw std::runtime_error(err.msg()); throw std::runtime_error(err.msg());
} }
return v; return v;

@ -3,7 +3,7 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory) cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
@ -29,7 +29,5 @@ py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init) add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto) cc_library(net SRCS net.cc DEPS op_registry)
# cc_library(net SRCS net.cc DEPS operator net_proto op_registry fc_op)
cc_library(net SRCS net.cc DEPS operator net_proto op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op) cc_test(net_op_test SRCS net_op_test.cc DEPS net add_op mul_op sigmoid_op softmax_op fc_op)

@ -0,0 +1,142 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/memory/memcpy.h"
namespace paddle {
namespace framework {
template <typename T>
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE(holder_->size() >= product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
}
template <typename T>
inline const T* Tensor::data() const {
check_memory_size<T>();
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
template <typename T>
inline T* Tensor::data() {
check_memory_size<T>();
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims);
return mutable_data<T>(place);
}
template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
PADDLE_ENFORCE(product(dims_) > 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
/* some versions of boost::variant don't have operator!= */
size_t size = product(dims_) * sizeof(T);
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size));
}
#ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(place)) {
holder_.reset(new PlaceholderImpl<T, platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), size));
}
#endif
offset_ = 0;
}
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
template <typename T>
inline void Tensor::ShareDataWith(const Tensor& src) {
src.check_memory_size<T>();
*this = src;
}
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
const platform::Place& dst_place) {
src.check_memory_size<T>();
Resize(src.dims());
auto src_place = src.holder_->place();
auto src_ptr = static_cast<const void*>(src.data<T>());
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
auto size = product(src.dims_) * sizeof(T);
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
#ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
}
#endif
}
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
check_memory_size<T>();
PADDLE_ENFORCE(begin_idx >= 0, "Slice begin index is less than zero.");
PADDLE_ENFORCE(end_idx <= dims_[0], "Slice end index is out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
int base = product(dims_) / dims_[0];
Tensor dst;
dst.holder_ = holder_;
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
dst.offset_ = offset_ + begin_idx * base * sizeof(T);
return dst;
}
inline void Tensor::Resize(const DDim& dims) { dims_ = dims; }
inline const DDim& Tensor::dims() const { return dims_; }
} // namespace framework
} // namespace paddle

@ -20,17 +20,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps) { void NetOp::CompleteAddOp(bool calc) {
auto grad_ops = std::make_shared<PlainNet>();
for (auto& op : ForwardOps->ops_) {
auto op_grad = OpRegistry::CreateGradOp(op);
grad_ops->AddOp(op_grad);
}
grad_ops->CompleteAddOp();
return grad_ops;
}
void PlainNet::CompleteAddOp(bool calc) {
add_op_done_ = true; add_op_done_ = true;
if (!calc) return; if (!calc) return;
std::unordered_set<std::string> input_set; std::unordered_set<std::string> input_set;
@ -70,7 +60,7 @@ void PlainNet::CompleteAddOp(bool calc) {
attrs_["temporary_index"] = tmp_index; attrs_["temporary_index"] = tmp_index;
} }
std::string PlainNet::DebugString() const { std::string NetOp::DebugString() const {
std::ostringstream os; std::ostringstream os;
os << OperatorBase::DebugString() << std::endl; os << OperatorBase::DebugString() << std::endl;
for (auto& op : ops_) { for (auto& op : ops_) {
@ -82,5 +72,7 @@ std::string PlainNet::DebugString() const {
return os.str(); return os.str();
} }
bool NetOp::IsNetOp() const { return true; }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -37,21 +37,7 @@ namespace framework {
* This is the base class of network, all the networks should implement the APIs * This is the base class of network, all the networks should implement the APIs
* it defines. * it defines.
*/ */
class Net : public OperatorBase { class NetOp : public OperatorBase {
public:
virtual void AddOp(const std::shared_ptr<OperatorBase>& op) = 0;
virtual void CompleteAddOp(bool calc) = 0;
};
using NetPtr = std::shared_ptr<Net>;
/**
* @brief a basic implementation of Net.
*
* PlainNet is a very simple Net, it create a list of operators, and run them
* sequentially following the order they added.
*/
class PlainNet : public Net {
public: public:
/** /**
* Infer all the operators' input and output variables' shapes, will be called * Infer all the operators' input and output variables' shapes, will be called
@ -80,15 +66,17 @@ class PlainNet : public Net {
/** /**
* @brief Add an operator by ptr * @brief Add an operator by ptr
*/ */
void AddOp(const std::shared_ptr<OperatorBase>& op) override { void AddOp(const std::shared_ptr<OperatorBase>& op) {
PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed");
ops_.push_back(op); ops_.push_back(op);
} }
void CompleteAddOp(bool calculate = true) override; void CompleteAddOp(bool calculate = true);
std::string DebugString() const override; std::string DebugString() const override;
bool IsNetOp() const override;
std::vector<std::shared_ptr<OperatorBase>> ops_; std::vector<std::shared_ptr<OperatorBase>> ops_;
private: private:
@ -100,7 +88,5 @@ class PlainNet : public Net {
} }
}; };
std::shared_ptr<PlainNet> AddBackwardOp(std::shared_ptr<PlainNet> ForwardOps);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -40,7 +40,7 @@ void AssertSameVectorWithoutOrder(const std::vector<T>& expected,
} }
TEST(OpKernel, all) { TEST(OpKernel, all) {
auto net = std::make_shared<PlainNet>(); auto net = std::make_shared<NetOp>();
ASSERT_NE(net, nullptr); ASSERT_NE(net, nullptr);
auto op1 = std::make_shared<TestOp>(); auto op1 = std::make_shared<TestOp>();
@ -69,30 +69,23 @@ TEST(OpKernel, all) {
net->Run(scope, dev_ctx); net->Run(scope, dev_ctx);
ASSERT_EQ(2, infer_shape_cnt); ASSERT_EQ(2, infer_shape_cnt);
ASSERT_EQ(2, run_cnt); ASSERT_EQ(2, run_cnt);
ASSERT_THROW(net->AddOp(op2), std::runtime_error); ASSERT_THROW(net->AddOp(op2), paddle::platform::EnforceNotMet);
}
TEST(AddBackwardOp, TestGradOp) {
auto net = std::make_shared<PlainNet>();
ASSERT_NE(net, nullptr);
net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
net->AddOp(
framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""}, {}));
auto grad_ops = AddBackwardOp(net);
for (auto& op : grad_ops->ops_) {
op->DebugString();
}
} }
// TODO(zhihong): add fc grad without registering. //! TODO(yuyang18): Refine Backward Op.
// TEST(AddBackwardOp, TestNoGradOp) { // TEST(AddBackwardOp, TestGradOp) {
// auto net = std::make_shared<PlainNet>(); // auto net = std::make_shared<NetOp>();
// ASSERT_NE(net, nullptr); // ASSERT_NE(net, nullptr);
// net->AddOp(framework::OpRegistry::CreateOp("fc", {"X", "W", "b"}, {"Y"}, // net->AddOp(framework::OpRegistry::CreateOp("mul", {"X", "Y"}, {"Out"}, {}));
// {})); auto grad_ops = AddBackwardOp(net); for (auto& op : grad_ops->ops_) { // net->AddOp(
// op->DebugString(); // framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {"Out"}, {}));
// } // net->AddOp(framework::OpRegistry::CreateOp("add_two", {"X", "Y"}, {""},
// } // {}));
// auto grad_ops = AddBackwardOp(net);
// for (auto& op : grad_ops->ops_) {
// op->DebugString();
// }
//}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save