Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/imperative
simple rnnrevert-15470-feature/imperative
commit
266e0b63cd
@ -1,3 +0,0 @@
|
||||
vendor/
|
||||
.glide/
|
||||
proto/*.go
|
@ -1,23 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
add_subdirectory(pserver/client/c)
|
||||
add_subdirectory(cmd/pserver)
|
||||
add_subdirectory(cmd/master)
|
||||
add_subdirectory(master/c)
|
||||
add_subdirectory(master)
|
||||
add_subdirectory(pserver)
|
||||
add_subdirectory(pserver/client)
|
||||
add_subdirectory(utils/networkhelper)
|
@ -1,15 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
go_binary(master SRC master.go)
|
@ -1,120 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/inconshreveable/log15"
|
||||
"github.com/namsral/flag"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.Int("port", 8080, "port of the master server.")
|
||||
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
|
||||
endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
|
||||
taskTimeoutDur := flag.Duration("task-timout-dur", 20*time.Minute, "task timout duration.")
|
||||
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
|
||||
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
|
||||
logLevel := flag.String("log-level", "info",
|
||||
"log level, possible values: debug, info, warn, error, crit")
|
||||
flag.Parse()
|
||||
|
||||
lvl, err := log.LvlFromString(*logLevel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
log.Root().SetHandler(
|
||||
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
|
||||
)
|
||||
|
||||
if *endpoints == "" {
|
||||
log.Warn("-endpoints not set, fault tolerance not be enabled.")
|
||||
}
|
||||
|
||||
var store master.Store
|
||||
if *endpoints != "" {
|
||||
eps := strings.Split(*endpoints, ",")
|
||||
ip, err := networkhelper.GetExternalIP()
|
||||
if err != nil {
|
||||
log.Crit("get external ip error", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", ip, *port)
|
||||
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
|
||||
if err != nil {
|
||||
log.Crit("error creating etcd client.", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
store = &master.InMemStore{}
|
||||
}
|
||||
|
||||
shutdown := func() {
|
||||
log.Info("shutting down gracefully")
|
||||
err := store.Shutdown()
|
||||
if err != nil {
|
||||
log.Error("shutdown error", log.Ctx{"error": err})
|
||||
}
|
||||
}
|
||||
|
||||
// Guaranteed to run even panic happens.
|
||||
defer shutdown()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
|
||||
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
|
||||
if err != nil {
|
||||
log.Crit("error creating new service.", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = rpc.Register(s)
|
||||
if err != nil {
|
||||
log.Crit("error registering to etcd.", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
|
||||
rpc.HandleHTTP()
|
||||
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
|
||||
if err != nil {
|
||||
log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
|
||||
panic(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
err = http.Serve(l, nil)
|
||||
if err != nil {
|
||||
log.Crit("error serving HTTP", log.Ctx{"error": err})
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-c
|
||||
}
|
@ -1 +0,0 @@
|
||||
pserver
|
@ -1,15 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
go_binary(pserver SRCS pserver.go DEPS paddle_go_optimizer)
|
@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/namsral/flag"
|
||||
"github.com/topicai/candy"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/pserver"
|
||||
log "github.com/inconshreveable/log15"
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.Int("port", 8001, "port of the pserver")
|
||||
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
|
||||
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
|
||||
"comma separated endpoint string for pserver to connect to etcd")
|
||||
dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
|
||||
etcdTTL := flag.Int("etcd-ttl", 5, "etcd time to live in seconds")
|
||||
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
|
||||
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
|
||||
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
|
||||
logLevel := flag.String("log-level", "info",
|
||||
"log level, possible values: debug, info, warn, error, crit")
|
||||
flag.Parse()
|
||||
|
||||
lvl, err := log.LvlFromString(*logLevel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
log.Root().SetHandler(
|
||||
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
|
||||
)
|
||||
|
||||
var idx int
|
||||
|
||||
var cp pserver.Checkpoint
|
||||
var e *pserver.EtcdClient
|
||||
if *index >= 0 {
|
||||
idx = *index
|
||||
} else {
|
||||
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *dialTimeout, *etcdTTL)
|
||||
idx, err = e.Register(*port)
|
||||
candy.Must(err)
|
||||
|
||||
cp, err = pserver.LoadCheckpoint(e, idx)
|
||||
if err != nil {
|
||||
if err == pserver.ErrCheckpointNotFound {
|
||||
log.Info("load checkpoint error", "error", err)
|
||||
} else {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shutdown := func() {
|
||||
log.Info("shutting down gracefully")
|
||||
sErr := e.Shutdown()
|
||||
if sErr != nil {
|
||||
log.Error("error shutting down", log.Ctx{"error": sErr})
|
||||
}
|
||||
}
|
||||
|
||||
// Guaranteed to run even panic happens.
|
||||
defer shutdown()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
|
||||
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
|
||||
candy.Must(err)
|
||||
|
||||
err = rpc.Register(s)
|
||||
candy.Must(err)
|
||||
|
||||
rpc.HandleHTTP()
|
||||
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
|
||||
candy.Must(err)
|
||||
|
||||
go func() {
|
||||
log.Info("serving pserver", log.Ctx{"port": *port})
|
||||
err = http.Serve(l, nil)
|
||||
candy.Must(err)
|
||||
}()
|
||||
|
||||
<-c
|
||||
}
|
@ -1,120 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package connection
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/rpc"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// TODO(helin): add TCP re-connect logic
|
||||
|
||||
// Conn is a connection to a parameter server
|
||||
type Conn struct {
|
||||
mu sync.Mutex
|
||||
client *rpc.Client
|
||||
waitConn chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new connection.
|
||||
func New() *Conn {
|
||||
c := &Conn{}
|
||||
return c
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.client == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// Connect connects the connection to a address.
|
||||
func (c *Conn) Connect(addr string) error {
|
||||
c.mu.Lock()
|
||||
if c.client != nil {
|
||||
err := c.client.Close()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
c.client = nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
client, err := rpc.DialHTTP("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.client == nil {
|
||||
c.client = client
|
||||
if c.waitConn != nil {
|
||||
close(c.waitConn)
|
||||
c.waitConn = nil
|
||||
}
|
||||
} else {
|
||||
err := client.Close()
|
||||
if err != nil {
|
||||
log.Errorln(err)
|
||||
}
|
||||
|
||||
return errors.New("client already set from a concurrent goroutine")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(helin): refactor Call to be able to perform given retry
|
||||
// policy.
|
||||
|
||||
// Call make a RPC call.
|
||||
//
|
||||
// Call will be blocked until the connection to remote RPC service
|
||||
// being established.
|
||||
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
|
||||
c.mu.Lock()
|
||||
client := c.client
|
||||
var waitCh chan struct{}
|
||||
if client == nil {
|
||||
if c.waitConn != nil {
|
||||
waitCh = c.waitConn
|
||||
} else {
|
||||
waitCh = make(chan struct{})
|
||||
c.waitConn = waitCh
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if waitCh != nil {
|
||||
// wait until new connection being established
|
||||
<-waitCh
|
||||
return c.Call(serviceMethod, args, reply)
|
||||
}
|
||||
|
||||
return client.Call(serviceMethod, args, reply)
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
if(WITH_TESTING)
|
||||
go_test(master_test)
|
||||
endif()
|
@ -1,15 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
go_library(paddle_master SHARED DEPS paddle_go_optimizer)
|
@ -1,196 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package main
|
||||
|
||||
/*
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <stdio.h>
|
||||
#define PADDLE_MASTER_OK 0
|
||||
#define PADDLE_MASTER_ERROR -1
|
||||
|
||||
#define PADDLE_SAVE_MODEL_OK 1
|
||||
#define PADDLE_SAVE_MODEL_SKIP 0
|
||||
|
||||
typedef int paddle_master_client;
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
log "github.com/inconshreveable/log15"
|
||||
)
|
||||
|
||||
var mu sync.Mutex
|
||||
var handleMap = make(map[C.paddle_master_client]*master.Client)
|
||||
var curHandle C.paddle_master_client
|
||||
|
||||
func init() {
|
||||
log.Root().SetHandler(
|
||||
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
|
||||
)
|
||||
}
|
||||
|
||||
func add(c *master.Client) C.paddle_master_client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
client := curHandle
|
||||
curHandle++
|
||||
handleMap[client] = c
|
||||
return client
|
||||
}
|
||||
|
||||
func get(client C.paddle_master_client) *master.Client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return handleMap[client]
|
||||
}
|
||||
|
||||
func remove(client C.paddle_master_client) *master.Client {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
h := handleMap[client]
|
||||
delete(handleMap, client)
|
||||
return h
|
||||
}
|
||||
|
||||
//export paddle_new_etcd_master_client
|
||||
//
|
||||
// bufSize is the record buffer size.
|
||||
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
|
||||
p := C.GoString(etcdEndpoints)
|
||||
endpoints := strings.Split(p, ",")
|
||||
c, err := master.NewClient(
|
||||
master.WithEtcd(endpoints, time.Duration(timeout)*time.Second),
|
||||
master.WithBuffer(bufSize),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return add(c)
|
||||
}
|
||||
|
||||
//export paddle_new_master_client
|
||||
//
|
||||
// bufSize is the record buffer size.
|
||||
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
|
||||
a := C.GoString(addr)
|
||||
c, err := master.NewClient(master.WithAddr(a), master.WithBuffer(bufSize))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return add(c)
|
||||
}
|
||||
|
||||
//export paddle_release_master_client
|
||||
func paddle_release_master_client(client C.paddle_master_client) {
|
||||
remove(client)
|
||||
}
|
||||
|
||||
//export paddle_start_get_records
|
||||
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
|
||||
c := get(client)
|
||||
c.StartGetRecords(int(pass))
|
||||
}
|
||||
|
||||
//export paddle_set_dataset
|
||||
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
|
||||
c := get(client)
|
||||
var paths []string
|
||||
for i := 0; i < int(size); i++ {
|
||||
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
|
||||
str := C.GoString(*ptr)
|
||||
paths = append(paths, str)
|
||||
}
|
||||
err := c.SetDataset(paths)
|
||||
if err != nil {
|
||||
log.Error("error set dataset",
|
||||
log.Ctx{"error": err, "paths": paths})
|
||||
return C.PADDLE_MASTER_ERROR
|
||||
}
|
||||
|
||||
return C.PADDLE_MASTER_OK
|
||||
}
|
||||
|
||||
// paddle_next_record gets the nexts training record.
|
||||
//
|
||||
// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
|
||||
//
|
||||
//export paddle_next_record
|
||||
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
|
||||
c := get(client)
|
||||
r, err := c.NextRecord()
|
||||
if err != nil {
|
||||
// NOTE: use errors to indicate pass ends
|
||||
if err.Error() == master.ErrAllTaskFailed.Error() ||
|
||||
err.Error() == master.ErrNoMoreAvailable.Error() ||
|
||||
err.Error() == master.ErrPassBefore.Error() {
|
||||
return -2
|
||||
}
|
||||
*record = (*C.uchar)(nil)
|
||||
return -1
|
||||
}
|
||||
|
||||
if len(r) == 0 {
|
||||
// Empty record
|
||||
*record = (*C.uchar)(nil)
|
||||
return 0
|
||||
}
|
||||
|
||||
size := C.size_t(len(r))
|
||||
*record = (*C.uchar)(C.malloc(size))
|
||||
C.memcpy(unsafe.Pointer(*record), unsafe.Pointer(&r[0]), size)
|
||||
return C.int(size)
|
||||
}
|
||||
|
||||
// paddle_request_save_model requests the master server to approve the
|
||||
// caller to save the model.
|
||||
//
|
||||
// returns 1 if the save the model request is approved, 0 if the
|
||||
// request is rejected because other trainer is saving the model, -1
|
||||
// if error happened.
|
||||
//
|
||||
//export paddle_request_save_model
|
||||
func paddle_request_save_model(client C.paddle_master_client, trainerID string, blockMS int) C.int {
|
||||
c := get(client)
|
||||
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
|
||||
if err != nil {
|
||||
log.Error("error request save model", log.Ctx{"error": err})
|
||||
return C.PADDLE_MASTER_ERROR
|
||||
}
|
||||
|
||||
if need {
|
||||
return C.PADDLE_SAVE_MODEL_OK
|
||||
}
|
||||
|
||||
return C.PADDLE_SAVE_MODEL_SKIP
|
||||
}
|
||||
|
||||
//export mem_free
|
||||
func mem_free(p unsafe.Pointer) {
|
||||
// "free" may be a better name for this function, but doing so
|
||||
// will cause calling any function of this library from Python
|
||||
// ctypes hanging.
|
||||
C.free(p)
|
||||
}
|
||||
|
||||
func main() {}
|
File diff suppressed because it is too large
Load Diff
@ -1,152 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/connection"
|
||||
"github.com/PaddlePaddle/recordio"
|
||||
)
|
||||
|
||||
const (
|
||||
totalTask = 20
|
||||
chunkPerTask = 10
|
||||
)
|
||||
|
||||
func TestGetFinishTask(t *testing.T) {
|
||||
const path = "/tmp/master_client_test_0"
|
||||
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func(l net.Listener) {
|
||||
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
|
||||
if sErr != nil {
|
||||
panic(sErr)
|
||||
}
|
||||
|
||||
server := rpc.NewServer()
|
||||
sErr = server.Register(s)
|
||||
if sErr != nil {
|
||||
panic(sErr)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
sErr = http.Serve(l, mux)
|
||||
if sErr != nil {
|
||||
panic(sErr)
|
||||
}
|
||||
}(l)
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i := 0; i < totalTask*chunkPerTask; i++ {
|
||||
w := recordio.NewWriter(f, -1, -1)
|
||||
_, err = w.Write(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// call Close to force RecordIO writing a chunk.
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Manually intialize client to avoid calling c.getRecords()
|
||||
c := &Client{}
|
||||
c.conn = connection.New()
|
||||
addr := fmt.Sprintf(":%d", p)
|
||||
ch := make(chan string, 1)
|
||||
ch <- addr
|
||||
go c.monitorMaster(ch)
|
||||
|
||||
err = c.SetDataset([]string{path})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
checkOnePass := func(i int) {
|
||||
var tasks []Task
|
||||
for idx := 0; idx < totalTask; idx++ {
|
||||
task, cErr := c.getTask(i)
|
||||
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
|
||||
t.Fatalf("error: %v, pass: %d\n", cErr, i)
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
}
|
||||
|
||||
// getting task before task finishes should return error
|
||||
_, cErr := c.getTask(i)
|
||||
if cErr == nil {
|
||||
t.Fatalf("Should get error, pass: %d\n", i)
|
||||
}
|
||||
|
||||
cErr = c.taskFinished(tasks[0].Meta.ID)
|
||||
if cErr != nil {
|
||||
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
|
||||
}
|
||||
// call taskFailed once won't put the task to failed queue, just ensure
|
||||
// the call
|
||||
cErr = c.taskFailed(tasks[0].Meta)
|
||||
if cErr != nil {
|
||||
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
|
||||
}
|
||||
|
||||
tasks = tasks[1:]
|
||||
_, cErr = c.getTask(i)
|
||||
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
|
||||
t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
|
||||
}
|
||||
|
||||
for _, task := range tasks {
|
||||
cErr = c.taskFinished(task.Meta.ID)
|
||||
if cErr != nil {
|
||||
t.Fatal(cErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
// init pass data
|
||||
c.StartGetRecords(i)
|
||||
checkOnePass(i)
|
||||
}
|
||||
}
|
@ -1,150 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
"github.com/PaddlePaddle/recordio"
|
||||
)
|
||||
|
||||
// tool function for testing output goroutine ids
|
||||
func goid() int {
|
||||
var buf [64]byte
|
||||
n := runtime.Stack(buf[:], false)
|
||||
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
|
||||
id, err := strconv.Atoi(idField)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func TestNextRecord(t *testing.T) {
|
||||
const (
|
||||
path = "/tmp/master_client_TestFull"
|
||||
total = 50
|
||||
)
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ss := strings.Split(l.Addr().String(), ":")
|
||||
p, err := strconv.Atoi(ss[len(ss)-1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func(l net.Listener) {
|
||||
s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
server := rpc.NewServer()
|
||||
err = server.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(rpc.DefaultRPCPath, server)
|
||||
err = http.Serve(l, mux)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}(l)
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
w := recordio.NewWriter(f, 1, -1)
|
||||
for i := 0; i < total; i++ {
|
||||
_, err = w.Write([]byte{byte(i)})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// start several client to test task fetching
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 4; i++ {
|
||||
wg.Add(1)
|
||||
// test for multiple concurrent clients
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// each go-routine needs a single client connection instance
|
||||
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
|
||||
if e != nil {
|
||||
t.Fatal(e)
|
||||
}
|
||||
e = c.SetDataset([]string{path})
|
||||
if e != nil {
|
||||
panic(e)
|
||||
}
|
||||
|
||||
// test for n passes
|
||||
for pass := 0; pass < 10; pass++ {
|
||||
c.StartGetRecords(pass)
|
||||
|
||||
received := make(map[byte]bool)
|
||||
taskid := 0
|
||||
for {
|
||||
r, e := c.NextRecord()
|
||||
if e != nil {
|
||||
// ErrorPassAfter will wait, else break for next pass
|
||||
if e.Error() == master.ErrPassBefore.Error() ||
|
||||
e.Error() == master.ErrNoMoreAvailable.Error() {
|
||||
break
|
||||
}
|
||||
t.Fatal(pass, taskid, "Read error:", e)
|
||||
}
|
||||
if len(r) != 1 {
|
||||
t.Fatal(pass, taskid, "Length should be 1.", r)
|
||||
}
|
||||
if received[r[0]] {
|
||||
t.Fatal(pass, taskid, "Received duplicate.", received, r)
|
||||
}
|
||||
taskid++
|
||||
received[r[0]] = true
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
@ -1,201 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/coreos/etcd/clientv3/concurrency"
|
||||
log "github.com/inconshreveable/log15"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultLockPath is the default etcd master lock path.
|
||||
DefaultLockPath = "/master/lock"
|
||||
// DefaultStatePath is the default etcd key for master state.
|
||||
DefaultStatePath = "/master/state"
|
||||
// DefaultAddrPath is the default etcd key for master address.
|
||||
DefaultAddrPath = "/master/addr"
|
||||
)
|
||||
|
||||
// EtcdClient is the etcd client that the master uses for fault
|
||||
// tolerance and service registry.
|
||||
type EtcdClient struct {
|
||||
lockPath string
|
||||
statePath string
|
||||
client *clientv3.Client
|
||||
lock *concurrency.Mutex
|
||||
sess *concurrency.Session
|
||||
}
|
||||
|
||||
// NewEtcdClient creates a new EtcdClient.
|
||||
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
|
||||
log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: endpoints,
|
||||
DialTimeout: dialTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lock := concurrency.NewMutex(sess, lockPath)
|
||||
// It's fine for the lock to get stuck, in this case we have
|
||||
// multiple master servers running (only configured to have
|
||||
// one master running, but split-brain problem may cause
|
||||
// multiple master servers running), and the cluster management
|
||||
// software will kill one of them.
|
||||
log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
|
||||
err = lock.Lock(context.TODO())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
|
||||
|
||||
put := clientv3.OpPut(addrPath, addr)
|
||||
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Crit("No longer owns the master lock. Exiting.")
|
||||
panic("No longer owns the master lock. Exiting.")
|
||||
}
|
||||
|
||||
e := &EtcdClient{
|
||||
lockPath: lockPath,
|
||||
statePath: statePath,
|
||||
client: cli,
|
||||
lock: lock,
|
||||
sess: sess,
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// Save saves the state into the etcd.
|
||||
func (e *EtcdClient) Save(state []byte) error {
|
||||
ctx := context.TODO()
|
||||
put := clientv3.OpPut(e.statePath, string(state))
|
||||
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Error("No longer owns the lock, trying to lock again")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
err := e.lock.Lock(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
// We lost the master lock and can not acquire
|
||||
// it back, it means some other master is
|
||||
// already started. We don't want cluster
|
||||
// management system to kill the master server
|
||||
// who is holding the lock and running
|
||||
// correctly. So the most feasible solution is
|
||||
// to kill current master server. The current
|
||||
// state is not saved, but the trainer's RPC
|
||||
// call will fail, so the trainer will retry.
|
||||
log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
|
||||
panic("Could not acquire the lock at %s: %v. Exiting.")
|
||||
}
|
||||
log.Info("Successfully acquired lock at %s.", e.lockPath)
|
||||
return e.Save(state)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load loads the state from etcd.
|
||||
func (e *EtcdClient) Load() ([]byte, error) {
|
||||
ctx := context.TODO()
|
||||
get := clientv3.OpGet(e.statePath)
|
||||
|
||||
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Error("No longer owns the lock, trying to lock and load again.")
|
||||
err = e.lock.Lock(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return e.Load()
|
||||
}
|
||||
|
||||
kvs := resp.Responses[0].GetResponseRange().Kvs
|
||||
if len(kvs) == 0 {
|
||||
// No state exists
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
state := kvs[0].Value
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the etcd client gracefully.
|
||||
func (e *EtcdClient) Shutdown() error {
|
||||
err := e.sess.Close()
|
||||
newErr := e.client.Close()
|
||||
if newErr != nil {
|
||||
if err == nil {
|
||||
err = newErr
|
||||
} else {
|
||||
log.Error("shutdown error", log.Ctx{"error": newErr})
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// GetKey gets the value by the specify key.
|
||||
func GetKey(c *clientv3.Client, key string, timeout time.Duration) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
resp, err := c.Get(ctx, key)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
kvs := resp.Kvs
|
||||
if len(kvs) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
v := kvs[0].Value
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
// watchKey watches the specify key and send to valChan if there is some event.
|
||||
func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
|
||||
rch := c.Watch(context.Background(), key)
|
||||
for wresp := range rch {
|
||||
for _, ev := range wresp.Events {
|
||||
// if received event is DELETE, the value will be an empty string
|
||||
log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
|
||||
valChan <- string(ev.Kv.Value)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,47 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master
|
||||
|
||||
import "sync"
|
||||
|
||||
// InMemStore is an in memory implementation of Store interface.
|
||||
//
|
||||
// It does not tolerate the fault that causes the program to crash.
|
||||
type InMemStore struct {
|
||||
mu sync.Mutex
|
||||
buf []byte
|
||||
}
|
||||
|
||||
// Save saves the state into the in-memory store.
|
||||
func (m *InMemStore) Save(state []byte) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.buf = state
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load loads the state from the in-memory store.
|
||||
func (m *InMemStore) Load() ([]byte, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
return m.buf, nil
|
||||
}
|
||||
|
||||
// Shutdown shuts down the in mem store.
|
||||
func (m *InMemStore) Shutdown() error {
|
||||
return nil
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@ -1,52 +0,0 @@
|
||||
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package master
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestPartitionCount(t *testing.T) {
|
||||
cs := make([]Chunk, 100)
|
||||
ts := partition(cs, 5)
|
||||
if len(ts) != 20 {
|
||||
t.Error(len(ts))
|
||||
}
|
||||
|
||||
cs = make([]Chunk, 101)
|
||||
ts = partition(cs, 5)
|
||||
if len(ts) != 21 {
|
||||
t.Error(len(ts))
|
||||
}
|
||||
|
||||
ts = partition(cs, 1)
|
||||
if len(ts) != 101 {
|
||||
t.Error(len(ts))
|
||||
}
|
||||
|
||||
ts = partition(cs, 0)
|
||||
if len(ts) != 101 {
|
||||
t.Error(len(ts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPartionIndex(t *testing.T) {
|
||||
cs := make([]Chunk, 100)
|
||||
ts := partition(cs, 20)
|
||||
for i := range ts {
|
||||
// test auto increament ids
|
||||
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
|
||||
t.Error(ts[i], i)
|
||||
}
|
||||
}
|
||||
}
|
@ -1,72 +0,0 @@
|
||||
package master_test
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/coreos/etcd/embed"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewServiceWithEtcd(t *testing.T) {
|
||||
// setup an embed etcd server
|
||||
etcdDir, err := ioutil.TempDir("", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cfg := embed.NewConfig()
|
||||
lpurl, _ := url.Parse("http://localhost:0")
|
||||
lcurl, _ := url.Parse("http://localhost:0")
|
||||
cfg.LPUrls = []url.URL{*lpurl}
|
||||
cfg.LCUrls = []url.URL{*lcurl}
|
||||
cfg.Dir = etcdDir
|
||||
e, err := embed.StartEtcd(cfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
e.Close()
|
||||
if err := os.RemoveAll(etcdDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
<-e.Server.ReadyNotify()
|
||||
|
||||
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
|
||||
endpoint := "127.0.0.1:" + port
|
||||
|
||||
ep := []string{endpoint}
|
||||
masterAddr := "127.0.0.1:3306"
|
||||
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = master.NewService(store, 10, 10, 3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: ep,
|
||||
DialTimeout: 3 * time.Second,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v, err := master.GetKey(cli, master.DefaultAddrPath, 3*time.Second)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := cli.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// test master process registry itself into etcd server.
|
||||
assert.Equal(t, masterAddr, v, "master process should registry itself into etcd server.")
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
# Ignore everything in this directory
|
||||
*
|
||||
# Except this file
|
||||
!.gitignore
|
@ -1,17 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
if(WITH_TESTING)
|
||||
go_test(pserver_test DEPS paddle_go_optimizer gen_proto_go)
|
||||
endif()
|
@ -1,17 +0,0 @@
|
||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
if(WITH_TESTING)
|
||||
go_test(pserver_client_test DEPS paddle_go_optimizer)
|
||||
endif()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue