Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into cpu_mem
commit
68ab1ef4db
@ -0,0 +1,20 @@
|
||||
INCLUDE(ExternalProject)
|
||||
|
||||
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
|
||||
|
||||
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3)
|
||||
|
||||
ExternalProject_Add(
|
||||
eigen3
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
|
||||
URL_MD5 "1a47e78efe365a97de0c022d127607c3"
|
||||
PREFIX ${EIGEN_SOURCE_DIR}
|
||||
UPDATE_COMMAND ""
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
INSTALL_COMMAND ""
|
||||
TEST_COMMAND ""
|
||||
)
|
||||
|
||||
LIST(APPEND external_project_dependencies eigen3)
|
@ -0,0 +1,124 @@
|
||||
# Design of Scope in Paddle
|
||||
|
||||
## Overview
|
||||
|
||||
Scope is an important concept in programming languages, which defines a program region that a set of bindings between names and entities applies. In a specific scope, a valid name is uniquely associated with an entity, such as a variable. And in another scope, this name may refer to other entity or nothing at all. It clearly restricts the visibility and validity of names in a program. Hence **Scope** is introduced to PaddlePaddle to manage variables in context. But different from the original abstract concept, Scope now becomes an object with two important attributes:
|
||||
|
||||
- Scope is an association of a name to variable.
|
||||
- Variables in a parent scope can be retrieved from local scope.
|
||||
|
||||
A detailed explanation of these two attributes goes as following.
|
||||
|
||||
|
||||
## Scope is an association of a name to variable.
|
||||
|
||||
Scope is an association of a name to variable. All variables belong to `Scope`. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. One net can run in different scopes and update different variable in the scope.
|
||||
|
||||
|
||||
1. Scope only contains a map of a name to variable.
|
||||
|
||||
All parameters, data, states in a Net should be variables and stored inside a scope. Each op should get inputs and outputs to do computation from a scope, such as data buffer, state(momentum) etc.
|
||||
|
||||
1. Variable can only be created by Scope and a variable can only be got from Scope. User cannot create or get a variable outside a scope. This is a constraints of our framework, and will keep our framework simple and clear.
|
||||
|
||||
1. Scope only contains methods that are used to Create and Get Variables. Scope do not contain Operators and have no information to run them.
|
||||
`Net` is designed to drive the computation and Scope only contains a map of variables. There is no computation logic inside a `Scope`. Scope just handles the lifetime management of variables.
|
||||
- `Create` is used to create a Variable by its name and add the mapping relation.
|
||||
- `Get` is used to find a Variable by name.
|
||||
|
||||
1. Every variable only belongs to one certain Scope.
|
||||
|
||||
Variable can not belong to many scopes. If you want to use variables from parent scope, you can use `parent scope`.
|
||||
|
||||
1. Scope should destruct all Variables inside it when itself is destructed. User can never store `Variable` pointer somewhere else.
|
||||
|
||||
Because Variable can only be got from Scope. When destroying Scope, we also need to destroy all the Variables in it. If user store `Variable` pointer to private data member or some global variable, the pointer will be a invalid pointer when associated `Scope` is destroyed.
|
||||
|
||||
```cpp
|
||||
class Scope {
|
||||
public:
|
||||
Variable* CreateVariable(const std::string& name);
|
||||
const Variable* GetVariable(const std::string& name) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::unique_ptr<Vairable>> vars_;
|
||||
};
|
||||
```
|
||||
|
||||
|
||||
## Parent scope and local scope
|
||||
|
||||
Just like [scope](https://en.wikipedia.org/wiki/Scope_(computer_science)) in programming languages, `Scope` in the neural network can also be a local scope. There are two attributes about local scope.
|
||||
|
||||
1. We can create local variables in a local scope. When that local scope are destroyed, all local variables should also be destroyed.
|
||||
2. Variables in a parent scope can be retrieved from local scopes of that parent scope, i.e., when user get a variable from a scope, it will try to search this variable in current scope. If there is no such variable in the local scope, `scope` will keep searching from its parent, until the variable is found or there is no parent.
|
||||
|
||||
```cpp
|
||||
class Scope {
|
||||
public:
|
||||
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
|
||||
|
||||
Variable* GetVariable(const std::string& name) const {
|
||||
Variable* var = GetVarLocally(name);
|
||||
if (var != nullptr) {
|
||||
return var;
|
||||
} else if (parent_ != nullptr) {
|
||||
return parent_->GetVariable(name);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Scope> parent_ {nullptr};
|
||||
};
|
||||
```
|
||||
|
||||
In `Scope` class, there is a private data member called `parent_`. `parent_` is a smart pointer to its parent scope. When user `Get` a variable by its `name`, the `name` will be searched inside the current scope. If the variable cannot be found locally and parent scope is not a `nullptr`, the variable will be searched inside that parent scope. `parent_` pointer's default value is `nullptr`. It means that the scope is a global scope when `parent_` is nullptr.
|
||||
|
||||
A local scope is very useful when we implement Recurrent Neural Network. Each timestep of an RNN should be a `Net`. Each `Net` of timestep (`StepNet` for short) should use an independent local scope. Just like variables in a while loop is inside a local scope in programming languages. By using a single `StepNet` and changing local scope, we can implement an RNN easily.
|
||||
|
||||
# Interface Design
|
||||
|
||||
```cpp
|
||||
class Variable {
|
||||
private:
|
||||
Variable() = default;
|
||||
friend class Scope;
|
||||
};
|
||||
|
||||
class Scope {
|
||||
private:
|
||||
Scope(const std::shared_ptr<Scope>& parent = nullptr);
|
||||
|
||||
public:
|
||||
static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr);
|
||||
|
||||
// return nullptr if not found.
|
||||
Variable* GetVariable(const std::string& name) const;
|
||||
|
||||
// return Error if already contains same name variable.
|
||||
Error CreateVariable(const std::string& name);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Scope> parent_;
|
||||
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
|
||||
};
|
||||
```
|
||||
## Only scope can create a variable
|
||||
|
||||
To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `CreateVariable` can construct `Variable`.
|
||||
|
||||
## When scope destroyed, all variables inside this scope should be destroyed together
|
||||
|
||||
The scope hold unique pointers for all variables. User can `GetVariable` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together.
|
||||
|
||||
## Sharing a parent scope
|
||||
|
||||
Local scope contains a `parent_` pointer. It is a linked-list for scopes. Using a `shared_ptr` because when a local scope is using, its parents cannot be destroyed.
|
||||
|
||||
Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shared pointer. We cannot construct a scope variable, because it cannot be passed to other scope as `parent` pointer.
|
||||
|
||||
## Orthogonal interface
|
||||
|
||||
`GetVariable` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `CreateVariable` will return a `Error` when there is a name conflict locally. Combine `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily.
|
@ -1,44 +0,0 @@
|
||||
if(NOT CMAKE_Go_COMPILER)
|
||||
if(NOT $ENV{GO_COMPILER} STREQUAL "")
|
||||
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT)
|
||||
|
||||
if(CMAKE_Go_FLAGS_ENV_INIT)
|
||||
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler")
|
||||
endif()
|
||||
|
||||
if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT})
|
||||
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
set(Go_BIN_PATH
|
||||
$ENV{GOPATH}
|
||||
$ENV{GOROOT}
|
||||
$ENV{GOROOT}/../bin
|
||||
$ENV{GO_COMPILER}
|
||||
/usr/bin
|
||||
/usr/local/bin
|
||||
)
|
||||
|
||||
if(CMAKE_Go_COMPILER_INIT)
|
||||
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler")
|
||||
else()
|
||||
find_program(CMAKE_Go_COMPILER
|
||||
NAMES go
|
||||
PATHS ${Go_BIN_PATH}
|
||||
)
|
||||
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION)
|
||||
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}")
|
||||
message("-- The Golang compiler identification is ${VERSION}")
|
||||
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
mark_as_advanced(CMAKE_Go_COMPILER)
|
||||
|
||||
configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in
|
||||
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)
|
||||
|
||||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
|
@ -1,8 +0,0 @@
|
||||
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@")
|
||||
set(CMAKE_Go_COMPILER_LOADED 1)
|
||||
|
||||
set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go)
|
||||
set(CMAKE_Go_LINKER_PREFERENCE 40)
|
||||
set(CMAKE_Go_OUTPUT_EXTENSION .o)
|
||||
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1)
|
||||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
|
@ -1,7 +0,0 @@
|
||||
if(NOT CMAKE_Go_COMPILE_OBJECT)
|
||||
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ")
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_Go_LINK_EXECUTABLE)
|
||||
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ")
|
||||
endif()
|
@ -1 +0,0 @@
|
||||
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "")
|
@ -1,45 +0,0 @@
|
||||
# Setting Paddle Compile Flags
|
||||
include(CheckCXXCompilerFlag)
|
||||
include(CheckCCompilerFlag)
|
||||
include(CheckCXXSymbolExists)
|
||||
include(CheckTypeSize)
|
||||
|
||||
function(CheckCompilerCXX11Flag)
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
|
||||
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
|
||||
endif()
|
||||
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
|
||||
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
|
||||
# https://gist.github.com/yamaya/2924292
|
||||
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X
|
||||
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1)
|
||||
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
|
||||
endif()
|
||||
else()
|
||||
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
|
||||
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
CheckCompilerCXX11Flag()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
|
||||
# Common gpu architectures: Kepler, Maxwell
|
||||
foreach(capability 30 35 50)
|
||||
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
|
||||
endforeach()
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0")
|
||||
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
|
||||
endif()
|
||||
|
||||
# Modern gpu architectures: Pascal
|
||||
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
|
||||
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
|
||||
endif()
|
||||
|
||||
set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})
|
@ -1,48 +0,0 @@
|
||||
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
|
||||
file(MAKE_DIRECTORY ${GOPATH})
|
||||
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle")
|
||||
file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH})
|
||||
|
||||
function(GO_LIBRARY NAME BUILD_TYPE)
|
||||
if(BUILD_TYPE STREQUAL "STATIC")
|
||||
set(BUILD_MODE -buildmode=c-archive)
|
||||
set(LIB_NAME "lib${NAME}.a")
|
||||
else()
|
||||
set(BUILD_MODE -buildmode=c-shared)
|
||||
if(APPLE)
|
||||
set(LIB_NAME "lib${NAME}.dylib")
|
||||
else()
|
||||
set(LIB_NAME "lib${NAME}.so")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
|
||||
file(RELATIVE_PATH rel ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
# find Paddle directory.
|
||||
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
|
||||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
|
||||
get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY)
|
||||
|
||||
# automatically get all dependencies specified in the source code
|
||||
# for given target.
|
||||
add_custom_target(${NAME}_goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
|
||||
|
||||
# make a symlink that references Paddle inside $GOPATH, so go get
|
||||
# will use the local changes in Paddle rather than checkout Paddle
|
||||
# in github.
|
||||
add_custom_target(${NAME}_copyPaddle
|
||||
COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
|
||||
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
|
||||
add_dependencies(${NAME}_goGet ${NAME}_copyPaddle)
|
||||
|
||||
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
|
||||
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
|
||||
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
|
||||
${CMAKE_GO_FLAGS} ${GO_SOURCE}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
|
||||
add_dependencies(${NAME} ${NAME}_goGet)
|
||||
|
||||
endfunction(GO_LIBRARY)
|
@ -1,14 +1,3 @@
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
|
||||
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
|
||||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
|
||||
|
||||
project(cxx_go C Go)
|
||||
|
||||
include(golang)
|
||||
include(flags)
|
||||
|
||||
go_library(paddle_pserver_cclient STATIC)
|
||||
|
||||
add_subdirectory(test)
|
||||
|
@ -1,22 +1,3 @@
|
||||
cmake_minimum_required(VERSION 3.0)
|
||||
|
||||
add_executable(main main.c)
|
||||
add_dependencies(main paddle_pserver_cclient)
|
||||
add_executable(test_cclient test_cclient.c)
|
||||
add_dependencies(test_cclient paddle_pserver_cclient)
|
||||
|
||||
if(APPLE)
|
||||
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
|
||||
else()
|
||||
set(CMAKE_EXE_LINKER_FLAGS "-pthread")
|
||||
endif()
|
||||
|
||||
if(PROJ_ROOT)
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR}/..)
|
||||
target_link_libraries(main ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread)
|
||||
target_link_libraries(test_cclient ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread)
|
||||
else(PROJ_ROOT)
|
||||
include_directories(${CMAKE_BINARY_DIR})
|
||||
target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
|
||||
target_link_libraries(test_cclient ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
|
||||
endif(PROJ_ROOT)
|
||||
cc_library(main SRCS main.c DEPS paddle_pserver_cclient)
|
||||
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient)
|
||||
|
@ -0,0 +1,181 @@
|
||||
package pserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/coreos/etcd/clientv3/concurrency"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// EtcdClient is the etcd client that the pserver uses for fault
|
||||
// tolerance, service registry and coordination.
|
||||
type EtcdClient struct {
|
||||
numPservers int
|
||||
etcdEndpoints string
|
||||
etcdClient *clientv3.Client
|
||||
// etcdTimeout is also used as retry intervals.
|
||||
etcdTimeout time.Duration
|
||||
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
|
||||
externalIP string
|
||||
// desired number of pservers in the job.
|
||||
// assume desired will not change during one training job.
|
||||
desired int
|
||||
}
|
||||
|
||||
// NewEtcdClient creates an EtcdClient
|
||||
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
|
||||
return &EtcdClient{
|
||||
etcdTimeout: timeout,
|
||||
numPservers: numPservers,
|
||||
etcdEndpoints: endpoints,
|
||||
}
|
||||
}
|
||||
|
||||
// Register registers the pserver on etcd
|
||||
//
|
||||
// Register returns the index of the current pserver.
|
||||
func (e *EtcdClient) Register() (int, error) {
|
||||
|
||||
var err error
|
||||
e.externalIP, err = networkhelper.GetExternalIP()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// initialize connection to etcd.
|
||||
ep := strings.Split(e.etcdEndpoints, ",")
|
||||
for {
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: ep,
|
||||
DialTimeout: e.etcdTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("connect to etcd error: %v", err)
|
||||
time.Sleep(e.etcdTimeout)
|
||||
continue
|
||||
}
|
||||
e.etcdClient = cli
|
||||
log.Debugf("inited client to %s", e.etcdEndpoints)
|
||||
break
|
||||
}
|
||||
// init /ps_desired using transaction, for multiple pservers may want to write
|
||||
// it at the same time.
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
_, err := e.initDesiredPsercers(ctx, e.numPservers)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
time.Sleep(e.etcdTimeout)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
// TODO: when implementing extending or reducing pservers, /ps_desired is
|
||||
// changed, then we need to watch /ps_desired node for events. For now, just
|
||||
// write once when init and read from it.
|
||||
// wait and set s.desired init value
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
resp, err := e.etcdClient.Get(ctx, PsDesired)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Errorf("getting %s error: %v", PsDesired, err)
|
||||
time.Sleep(e.etcdTimeout)
|
||||
continue
|
||||
}
|
||||
if len(resp.Kvs) != 0 {
|
||||
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
|
||||
if err != nil {
|
||||
log.Errorf("value of %s invalid %v\n", PsDesired, err)
|
||||
time.Sleep(e.etcdTimeout)
|
||||
// NOTE: wait util ps_desired value change
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var pserverIdx int
|
||||
// try register pserver node on etcd
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
var err error
|
||||
pserverIdx, err = e.registerPserverEtcd(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
time.Sleep(e.etcdTimeout)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return pserverIdx, nil
|
||||
}
|
||||
|
||||
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
|
||||
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
|
||||
dsStr := c.Get(PsDesired)
|
||||
if dsStr == "" {
|
||||
c.Put(PsDesired, strconv.Itoa(numPservers))
|
||||
}
|
||||
return nil
|
||||
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
|
||||
}
|
||||
|
||||
// registerPserverEtcd registers pserver node on etcd using transaction.
|
||||
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
|
||||
var idx int
|
||||
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
|
||||
registered := false
|
||||
for i := 0; i < e.desired; i++ {
|
||||
psKey := "/ps/" + strconv.Itoa(i)
|
||||
log.Debugf("checking %s", psKey)
|
||||
ps := c.Get(psKey)
|
||||
log.Debugf("got value (%s) for key: %s", ps, psKey)
|
||||
|
||||
if ps == "" {
|
||||
resp, err := e.etcdClient.Grant(context.TODO(), 5)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// find the first id and write info
|
||||
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
|
||||
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
|
||||
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
|
||||
if kaerr != nil {
|
||||
log.Errorf("keepalive etcd node error: %v", kaerr)
|
||||
return kaerr
|
||||
}
|
||||
|
||||
// Eat the keep alive message so etcd
|
||||
// will not expire the lease.
|
||||
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
|
||||
ka := <-ch
|
||||
log.Debugf("keepalive: %d\n", ka.TTL)
|
||||
}(ch)
|
||||
log.Debug("register finished")
|
||||
idx = i
|
||||
registered = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if registered == true {
|
||||
return nil
|
||||
}
|
||||
return errors.New("not registerd, may due to already have enough pservers")
|
||||
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return idx, nil
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue