Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/imperative

simple rnn
revert-15470-feature/imperative
JiabinYang 6 years ago
commit 266e0b63cd

@ -4,7 +4,6 @@ cache:
- $HOME/.ccache
- $HOME/.cache/pip
- $TRAVIS_BUILD_DIR/build/third_party
- $TRAVIS_BUILD_DIR/build_android/third_party
sudo: required
dist: trusty
services:
@ -13,7 +12,6 @@ os:
- linux
env:
- JOB=check_style
- JOB=build_android
addons:
ssh_known_hosts: 13.229.163.131
before_install:

@ -279,9 +279,6 @@ include(inference_lib) # add paddle fluid inference libraries
include_directories("${PADDLE_SOURCE_DIR}")
include_directories("${PADDLE_SOURCE_DIR}/paddle/legacy/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c")
set(EXTERNAL_LIBS
gflags
@ -320,21 +317,6 @@ if(USE_NNPACK)
list(APPEND EXTERNAL_LIBS ${NNPACK_LIBS})
endif(USE_NNPACK)
add_subdirectory(proto)
if(NOT MOBILE_INFERENCE AND NOT WITH_FLUID_ONLY)
# "add_subdirectory(go)" should be placed after the following loine,
# because it depends on paddle/optimizer.
add_subdirectory(paddle/legacy/optimizer)
endif()
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
# placed after this block, because they depends on it.
if(WITH_GOLANG)
enable_language(Go)
add_subdirectory(go)
endif(WITH_GOLANG)
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")

@ -748,7 +748,7 @@ function(grpc_library TARGET_NAME)
#FIXME(putcn): the follwoing line is supposed to generate *.pb.h and cc, but
# somehow it didn't. line 602 to 604 is to patching this. Leaving this here
# for now to enable dist CI.
protobuf_generate_cpp(grpc_proto_srcs grpc_proto_hdrs "${ABS_PROTO}")
paddle_protobuf_generate_cpp(grpc_proto_srcs grpc_proto_hdrs "${ABS_PROTO}")
set(grpc_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.cc")
set(grpc_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/${PROTO_WE}.grpc.pb.h")
cc_library("${TARGET_NAME}_proto" SRCS "${grpc_proto_srcs}")
@ -791,7 +791,7 @@ function(brpc_library TARGET_NAME)
get_filename_component(PROTO_WE ${brpc_library_PROTO} NAME_WE)
get_filename_component(PROTO_PATH ${ABS_PROTO} PATH)
protobuf_generate_cpp(brpc_proto_srcs brpc_proto_hdrs "${ABS_PROTO}")
paddle_protobuf_generate_cpp(brpc_proto_srcs brpc_proto_hdrs "${ABS_PROTO}")
cc_library("${TARGET_NAME}_proto" SRCS "${brpc_proto_srcs}")
cc_library("${TARGET_NAME}" SRCS "${brpc_library_SRCS}" DEPS "${TARGET_NAME}_proto" "${brpc_library_DEPS}")
endfunction()

3
go/.gitignore vendored

@ -1,3 +0,0 @@
vendor/
.glide/
proto/*.go

@ -1,23 +0,0 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master)
add_subdirectory(master/c)
add_subdirectory(master)
add_subdirectory(pserver)
add_subdirectory(pserver/client)
add_subdirectory(utils/networkhelper)

@ -1,15 +0,0 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
go_binary(master SRC master.go)

@ -1,120 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"os/signal"
"strconv"
"strings"
"time"
log "github.com/inconshreveable/log15"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
)
func main() {
port := flag.Int("port", 8080, "port of the master server.")
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warn, error, crit")
flag.Parse()
lvl, err := log.LvlFromString(*logLevel)
if err != nil {
panic(err)
}
log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
if *endpoints == "" {
log.Warn("-endpoints not set, fault tolerance not be enabled.")
}
var store master.Store
if *endpoints != "" {
eps := strings.Split(*endpoints, ",")
ip, err := networkhelper.GetExternalIP()
if err != nil {
log.Crit("get external ip error", log.Ctx{"error": err})
panic(err)
}
addr := fmt.Sprintf("%s:%d", ip, *port)
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
if err != nil {
log.Crit("error creating etcd client.", log.Ctx{"error": err})
panic(err)
}
} else {
store = &master.InMemStore{}
}
shutdown := func() {
log.Info("shutting down gracefully")
err := store.Shutdown()
if err != nil {
log.Error("shutdown error", log.Ctx{"error": err})
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil {
log.Crit("error creating new service.", log.Ctx{"error": err})
panic(err)
}
err = rpc.Register(s)
if err != nil {
log.Crit("error registering to etcd.", log.Ctx{"error": err})
panic(err)
}
rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil {
log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
panic(err)
}
go func() {
err = http.Serve(l, nil)
if err != nil {
log.Crit("error serving HTTP", log.Ctx{"error": err})
panic(err)
}
}()
<-c
}

@ -1 +0,0 @@
pserver

@ -1,15 +0,0 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
go_binary(pserver SRCS pserver.go DEPS paddle_go_optimizer)

@ -1,108 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
"net"
"net/http"
"net/rpc"
"os"
"os/signal"
"strconv"
"time"
"github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/inconshreveable/log15"
)
func main() {
port := flag.Int("port", 8001, "port of the pserver")
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warn, error, crit")
flag.Parse()
lvl, err := log.LvlFromString(*logLevel)
if err != nil {
panic(err)
}
log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
var idx int
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 {
idx = *index
} else {
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
idx, err = e.Register(*port)
candy.Must(err)
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
log.Info("load checkpoint error", "error", err)
} else {
panic(err)
}
}
}
shutdown := func() {
log.Info("shutting down gracefully")
sErr := e.Shutdown()
if sErr != nil {
log.Error("error shutting down", log.Ctx{"error": sErr})
}
}
// Guaranteed to run even panic happens.
defer shutdown()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
candy.Must(err)
err = rpc.Register(s)
candy.Must(err)
rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
candy.Must(err)
go func() {
log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil)
candy.Must(err)
}()
<-c
}

@ -1,120 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package connection
import (
"errors"
"net/rpc"
"sync"
log "github.com/sirupsen/logrus"
)
// TODO(helin): add TCP re-connect logic
// Conn is a connection to a parameter server
type Conn struct {
mu sync.Mutex
client *rpc.Client
waitConn chan struct{}
}
// New creates a new connection.
func New() *Conn {
c := &Conn{}
return c
}
// Close closes the connection.
func (c *Conn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
return nil
}
return c.client.Close()
}
// Connect connects the connection to a address.
func (c *Conn) Connect(addr string) error {
c.mu.Lock()
if c.client != nil {
err := c.client.Close()
if err != nil {
c.mu.Unlock()
return err
}
c.client = nil
}
c.mu.Unlock()
client, err := rpc.DialHTTP("tcp", addr)
if err != nil {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
c.client = client
if c.waitConn != nil {
close(c.waitConn)
c.waitConn = nil
}
} else {
err := client.Close()
if err != nil {
log.Errorln(err)
}
return errors.New("client already set from a concurrent goroutine")
}
return nil
}
// TODO(helin): refactor Call to be able to perform given retry
// policy.
// Call make a RPC call.
//
// Call will be blocked until the connection to remote RPC service
// being established.
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
c.mu.Lock()
client := c.client
var waitCh chan struct{}
if client == nil {
if c.waitConn != nil {
waitCh = c.waitConn
} else {
waitCh = make(chan struct{})
c.waitConn = waitCh
}
}
c.mu.Unlock()
if waitCh != nil {
// wait until new connection being established
<-waitCh
return c.Call(serviceMethod, args, reply)
}
return client.Call(serviceMethod, args, reply)
}

@ -1,17 +0,0 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(master_test)
endif()

@ -1,15 +0,0 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
go_library(paddle_master SHARED DEPS paddle_go_optimizer)

@ -1,196 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
/*
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
#define PADDLE_SAVE_MODEL_OK 1
#define PADDLE_SAVE_MODEL_SKIP 0
typedef int paddle_master_client;
*/
import "C"
import (
"strings"
"sync"
"time"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
log "github.com/inconshreveable/log15"
)
var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client
func init() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
)
}
func add(c *master.Client) C.paddle_master_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
curHandle++
handleMap[client] = c
return client
}
func get(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}
func remove(client C.paddle_master_client) *master.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
delete(handleMap, client)
return h
}
//export paddle_new_etcd_master_client
//
// bufSize is the record buffer size.
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints)
endpoints := strings.Split(p, ",")
c, err := master.NewClient(
master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
master.WithBuffer(bufSize),
)
if err != nil {
panic(err)
}
return add(c)
}
//export paddle_new_master_client
//
// bufSize is the record buffer size.
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
if err != nil {
panic(err)
}
return add(c)
}
//export paddle_release_master_client
func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
c := get(client)
c.StartGetRecords(int(pass))
}
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
var paths []string
for i := 0; i < int(size); i++ {
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
str := C.GoString(*ptr)
paths = append(paths, str)
}
err := c.SetDataset(paths)
if err != nil {
log.Error("error set dataset",
log.Ctx{"error": err, "paths": paths})
return C.PADDLE_MASTER_ERROR
}
return C.PADDLE_MASTER_OK
}
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r, err := c.NextRecord()
if err != nil {
// NOTE: use errors to indicate pass ends
if err.Error() == master.ErrAllTaskFailed.Error() ||
err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1
}
if len(r) == 0 {
// Empty record
*record = (*C.uchar)(nil)
return 0
}
size := C.size_t(len(r))
*record = (*C.uchar)(C.malloc(size))
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
return C.int(size)
}
// paddle_request_save_model requests the master server to approve the
// caller to save the model.
//
// returns 1 if the save the model request is approved, 0 if the
// request is rejected because other trainer is saving the model, -1
// if error happened.
//
//export paddle_request_save_model
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil {
log.Error("error request save model", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR
}
if need {
return C.PADDLE_SAVE_MODEL_OK
}
return C.PADDLE_SAVE_MODEL_SKIP
}
//export mem_free
func mem_free(p unsafe.Pointer) {
// "free" may be a better name for this function, but doing so
// will cause calling any function of this library from Python
// ctypes hanging.
C.free(p)
}
func main() {}

File diff suppressed because it is too large Load Diff

@ -1,152 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
const (
totalTask = 20
chunkPerTask = 10
)
func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0"
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
go func(l net.Listener) {
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if sErr != nil {
panic(sErr)
}
server := rpc.NewServer()
sErr = server.Register(s)
if sErr != nil {
panic(sErr)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
sErr = http.Serve(l, mux)
if sErr != nil {
panic(sErr)
}
}(l)
f, err := os.Create(path)
if err != nil {
panic(err)
}
for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1)
_, err = w.Write(nil)
if err != nil {
panic(err)
}
// call Close to force RecordIO writing a chunk.
err = w.Close()
if err != nil {
panic(err)
}
}
err = f.Close()
if err != nil {
panic(err)
}
// Manually intialize client to avoid calling c.getRecords()
c := &Client{}
c.conn = connection.New()
addr := fmt.Sprintf(":%d", p)
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, cErr := c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("error: %v, pass: %d\n", cErr, i)
}
tasks = append(tasks, task)
}
// getting task before task finishes should return error
_, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i)
}
cErr = c.taskFinished(tasks[0].Meta.ID)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
// call taskFailed once won't put the task to failed queue, just ensure
// the call
cErr = c.taskFailed(tasks[0].Meta)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
tasks = tasks[1:]
_, cErr = c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
}
for _, task := range tasks {
cErr = c.taskFinished(task.Meta.ID)
if cErr != nil {
t.Fatal(cErr)
}
}
}
for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i)
}
}

@ -1,150 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master_test
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
// tool function for testing output goroutine ids
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func TestNextRecord(t *testing.T) {
const (
path = "/tmp/master_client_TestFull"
total = 50
)
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil {
panic(err)
}
server := rpc.NewServer()
err = server.Register(s)
if err != nil {
panic(err)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
}
}(l)
f, err := os.Create(path)
if err != nil {
panic(err)
}
w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ {
_, err = w.Write([]byte{byte(i)})
if err != nil {
panic(err)
}
}
err = w.Close()
if err != nil {
panic(err)
}
err = f.Close()
if err != nil {
panic(err)
}
// start several client to test task fetching
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
// test for multiple concurrent clients
go func() {
defer wg.Done()
// each go-routine needs a single client connection instance
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
t.Fatal(e)
}
e = c.SetDataset([]string{path})
if e != nil {
panic(e)
}
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
received := make(map[byte]bool)
taskid := 0
for {
r, e := c.NextRecord()
if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
}
if len(r) != 1 {
t.Fatal(pass, taskid, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal(pass, taskid, "Received duplicate.", received, r)
}
taskid++
received[r[0]] = true
}
}
}()
}
wg.Wait()
}

@ -1,201 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
"context"
"time"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/inconshreveable/log15"
)
const (
// DefaultLockPath is the default etcd master lock path.
DefaultLockPath = "/master/lock"
// DefaultStatePath is the default etcd key for master state.
DefaultStatePath = "/master/state"
// DefaultAddrPath is the default etcd key for master address.
DefaultAddrPath = "/master/addr"
)
// EtcdClient is the etcd client that the master uses for fault
// tolerance and service registry.
type EtcdClient struct {
lockPath string
statePath string
client *clientv3.Client
lock *concurrency.Mutex
sess *concurrency.Session
}
// NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints,
DialTimeout: dialTimeout,
})
if err != nil {
return nil, err
}
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
if err != nil {
return nil, err
}
lock := concurrency.NewMutex(sess, lockPath)
// It's fine for the lock to get stuck, in this case we have
// multiple master servers running (only configured to have
// one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management
// software will kill one of them.
log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
err = lock.Lock(context.TODO())
if err != nil {
return nil, err
}
log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
if err != nil {
return nil, err
}
if !resp.Succeeded {
log.Crit("No longer owns the master lock. Exiting.")
panic("No longer owns the master lock. Exiting.")
}
e := &EtcdClient{
lockPath: lockPath,
statePath: statePath,
client: cli,
lock: lock,
sess: sess,
}
return e, nil
}
// Save saves the state into the etcd.
func (e *EtcdClient) Save(state []byte) error {
ctx := context.TODO()
put := clientv3.OpPut(e.statePath, string(state))
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
if err != nil {
return err
}
if !resp.Succeeded {
log.Error("No longer owns the lock, trying to lock again")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err := e.lock.Lock(ctx)
cancel()
if err != nil {
// We lost the master lock and can not acquire
// it back, it means some other master is
// already started. We don't want cluster
// management system to kill the master server
// who is holding the lock and running
// correctly. So the most feasible solution is
// to kill current master server. The current
// state is not saved, but the trainer's RPC
// call will fail, so the trainer will retry.
log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
panic("Could not acquire the lock at %s: %v. Exiting.")
}
log.Info("Successfully acquired lock at %s.", e.lockPath)
return e.Save(state)
}
return nil
}
// Load loads the state from etcd.
func (e *EtcdClient) Load() ([]byte, error) {
ctx := context.TODO()
get := clientv3.OpGet(e.statePath)
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit()
if err != nil {
return nil, err
}
if !resp.Succeeded {
log.Error("No longer owns the lock, trying to lock and load again.")
err = e.lock.Lock(context.Background())
if err != nil {
return nil, err
}
return e.Load()
}
kvs := resp.Responses[0].GetResponseRange().Kvs
if len(kvs) == 0 {
// No state exists
return nil, nil
}
state := kvs[0].Value
return state, nil
}
// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
err := e.sess.Close()
newErr := e.client.Close()
if newErr != nil {
if err == nil {
err = newErr
} else {
log.Error("shutdown error", log.Ctx{"error": newErr})
}
}
return err
}
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
return "", err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return "", nil
}
v := kvs[0].Value
return string(v), nil
}
// watchKey watches the specify key and send to valChan if there is some event.
func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
valChan <- string(ev.Kv.Value)
}
}
}

@ -1,47 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import "sync"
// InMemStore is an in memory implementation of Store interface.
//
// It does not tolerate the fault that causes the program to crash.
type InMemStore struct {
mu sync.Mutex
buf []byte
}
// Save saves the state into the in-memory store.
func (m *InMemStore) Save(state []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
m.buf = state
return nil
}
// Load loads the state from the in-memory store.
func (m *InMemStore) Load() ([]byte, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.buf, nil
}
// Shutdown shuts down the in mem store.
func (m *InMemStore) Shutdown() error {
return nil
}

File diff suppressed because it is too large Load Diff

@ -1,52 +0,0 @@
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import "testing"
func TestPartitionCount(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 5)
if len(ts) != 20 {
t.Error(len(ts))
}
cs = make([]Chunk, 101)
ts = partition(cs, 5)
if len(ts) != 21 {
t.Error(len(ts))
}
ts = partition(cs, 1)
if len(ts) != 101 {
t.Error(len(ts))
}
ts = partition(cs, 0)
if len(ts) != 101 {
t.Error(len(ts))
}
}
func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
// test auto increament ids
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i)
}
}
}

@ -1,72 +0,0 @@
package master_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed"
"github.com/stretchr/testify/assert"
)
func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
ep := []string{endpoint}
masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil {
t.Fatal(err)
}
_, err = master.NewService(store, 10, 10, 3)
if err != nil {
t.Fatal(err)
}
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: 3 * time.Second,
})
if err != nil {
t.Fatal(err)
}
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
if err != nil {
t.Fatal(err)
}
if err := cli.Close(); err != nil {
t.Fatal(err)
}
// test master process registry itself into etcd server.
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
}

@ -1,4 +0,0 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

@ -1,17 +0,0 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer gen_proto_go)
endif()

@ -1,17 +0,0 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save