commit
36f8b84809
@ -0,0 +1,29 @@
|
||||
# Registers Eigen3 as an external project that is fetched from git.
# The update, configure, build, install and test commands are all set to
# empty strings, so no command is run for any of those steps.
include(ExternalProject)

# Everything ExternalProject creates for eigen3 lives under this prefix.
set(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)

# Expose the checked-out sources (<prefix>/src/<name>) to the compiler.
include_directories(${EIGEN_SOURCE_DIR}/src/eigen3)

ExternalProject_Add(
    eigen3
    ${EXTERNAL_PROJECT_LOG_ARGS}

    # Alternative source 1 -- latest release from the official website:
    #   URL      "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
    #   URL_MD5  "1a47e78efe365a97de0c022d127607c3"
    #
    # Alternative source 2 -- plain-http (no SSL) download from bazel's mirror:
    #   URL      "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
    #   URL_MD5  "4645c66075982da6fa0bcf6b20f3e8f7"

    # Source actually used: a github mirror, pinned to one commit.
    GIT_REPOSITORY      "https://github.com/RLovelett/eigen.git"
    GIT_TAG             "a46d2e7337c4656f00abe54a8115f6d76153a048"
    PREFIX              ${EIGEN_SOURCE_DIR}
    UPDATE_COMMAND      ""
    CONFIGURE_COMMAND   ""
    BUILD_COMMAND       ""
    INSTALL_COMMAND     ""
    TEST_COMMAND        ""
)

# Record eigen3 in the list of external project dependencies.
list(APPEND external_project_dependencies eigen3)
|
@ -0,0 +1,124 @@
|
||||
# Design of Scope in Paddle
|
||||
|
||||
## Overview
|
||||
|
||||
Scope is an important concept in programming languages: it defines a program region in which a set of bindings between names and entities applies. In a specific scope, a valid name is uniquely associated with an entity, such as a variable. In another scope, the same name may refer to another entity or to nothing at all. Scope clearly restricts the visibility and validity of names in a program. Hence **Scope** is introduced to PaddlePaddle to manage variables in context. Unlike the original abstract concept, however, Scope here is an object with two important attributes:
|
||||
|
||||
- Scope is an association of a name to variable.
|
||||
- Variables in a parent scope can be retrieved from local scope.
|
||||
|
||||
A detailed explanation of these two attributes follows.
|
||||
|
||||
|
||||
## Scope is an association of a name to variable.
|
||||
|
||||
Scope is an association of a name to variable. All variables belong to `Scope`. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. One net can run in different scopes and update different variables in each scope.
|
||||
|
||||
|
||||
1. Scope only contains a map of a name to variable.
|
||||
|
||||
All parameters, data, and states in a Net should be variables stored inside a scope. Each op should get its inputs and outputs for computation from a scope, such as data buffers, state (momentum), etc.
|
||||
|
||||
1. Variable can only be created by Scope and a variable can only be got from Scope. User cannot create or get a variable outside a scope. This is a constraint of our framework, and will keep our framework simple and clear.
|
||||
|
||||
1. Scope only contains methods that are used to Create and Get Variables. Scope does not contain Operators and has no information to run them.
|
||||
`Net` is designed to drive the computation and Scope only contains a map of variables. There is no computation logic inside a `Scope`. Scope just handles the lifetime management of variables.
|
||||
- `Create` is used to create a Variable by its name and add the mapping relation.
|
||||
- `Get` is used to find a Variable by name.
|
||||
|
||||
1. Every variable only belongs to one certain Scope.
|
||||
|
||||
A variable cannot belong to multiple scopes. If you want to use variables from a parent scope, you can use `parent scope`.
|
||||
|
||||
1. Scope should destruct all Variables inside it when it is itself destructed. Users should never store a `Variable` pointer anywhere else.
|
||||
|
||||
This is because a Variable can only be obtained from a Scope: when a Scope is destroyed, all the Variables in it must be destroyed as well. If a user stores a `Variable` pointer in a private data member or a global variable, that pointer becomes invalid when the associated `Scope` is destroyed.
|
||||
|
||||
```cpp
|
||||
class Scope {
|
||||
public:
|
||||
Variable* CreateVariable(const std::string& name);
|
||||
const Variable* GetVariable(const std::string& name) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
|
||||
};
|
||||
```
|
||||
|
||||
|
||||
## Parent scope and local scope
|
||||
|
||||
Just like [scope](https://en.wikipedia.org/wiki/Scope_(computer_science)) in programming languages, `Scope` in the neural network can also be a local scope. There are two attributes about local scope.
|
||||
|
||||
1. We can create local variables in a local scope. When that local scope is destroyed, all local variables should also be destroyed.
|
||||
2. Variables in a parent scope can be retrieved from local scopes of that parent scope, i.e., when a user gets a variable from a scope, the scope first searches for the variable in itself. If there is no such variable in the local scope, `scope` will keep searching from its parent, until the variable is found or there is no parent.
|
||||
|
||||
```cpp
|
||||
class Scope {
|
||||
public:
|
||||
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
|
||||
|
||||
Variable* GetVariable(const std::string& name) const {
|
||||
auto it = vars_.find(name);
|
||||
if (it != vars_.end()) {
|
||||
return it->second.get();
|
||||
} else if (parent_ != nullptr) {
|
||||
return parent_->GetVariable(name);
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<Scope> parent_ {nullptr};
|
||||
};
|
||||
```
|
||||
|
||||
In the `Scope` class, there is a private data member called `parent_`. `parent_` is a smart pointer to the parent scope. When a user `Get`s a variable by its `name`, the `name` is first searched for inside the current scope. If the variable cannot be found locally and the parent scope is not `nullptr`, the variable is searched for inside that parent scope. The default value of `parent_` is `nullptr`, which means that a scope whose `parent_` is `nullptr` is a global scope.
|
||||
|
||||
A local scope is very useful when we implement a Recurrent Neural Network. Each timestep of an RNN should be a `Net`. Each `Net` of a timestep (`StepNet` for short) should use an independent local scope, just as variables in a while loop live inside a local scope in programming languages. By using a single `StepNet` and changing the local scope, we can implement an RNN easily.
|
||||
|
||||
# Interface Design
|
||||
|
||||
```cpp
|
||||
class Variable {
|
||||
private:
|
||||
Variable() = default;
|
||||
friend class Scope;
|
||||
};
|
||||
|
||||
class Scope {
|
||||
private:
|
||||
Scope(const std::shared_ptr<Scope>& parent = nullptr);
|
||||
|
||||
public:
|
||||
static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr);
|
||||
|
||||
// return nullptr if not found.
|
||||
Variable* GetVariable(const std::string& name) const;
|
||||
|
||||
// return if already contains same name variable.
|
||||
Variable* CreateVariable(const std::string& name);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Scope> parent_;
|
||||
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
|
||||
};
|
||||
```
|
||||
## Only scope can create a variable
|
||||
|
||||
To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function and make Scope a friend class of Variable. Then only `CreateVariable` can construct a `Variable`.
|
||||
|
||||
## When scope destroyed, all variables inside this scope should be destroyed together
|
||||
|
||||
The scope holds unique pointers to all of its variables. A user can `GetVariable` from a scope, but should not keep the returned pointer as a member variable, because when the scope is destroyed, all variables inside it are destroyed together.
|
||||
|
||||
## Sharing a parent scope
|
||||
|
||||
A local scope contains a `parent_` pointer, so scopes form a linked list. A `shared_ptr` is used because a parent scope must not be destroyed while one of its local scopes is still in use.
|
||||
|
||||
Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shared pointer. We cannot construct a scope variable directly, because it could not be passed to another scope as its `parent` pointer.
|
||||
|
||||
## Orthogonal interface
|
||||
|
||||
`GetVariable` will return `nullptr` when `name` is not found, so it can be used as a `Contains` method. `CreateVariable` will return an `Error` when there is a name conflict locally. Combining `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily.
|
@ -1,44 +0,0 @@
|
||||
if(NOT CMAKE_Go_COMPILER)
|
||||
if(NOT $ENV{GO_COMPILER} STREQUAL "")
|
||||
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT)
|
||||
|
||||
if(CMAKE_Go_FLAGS_ENV_INIT)
|
||||
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler")
|
||||
endif()
|
||||
|
||||
if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT})
|
||||
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
set(Go_BIN_PATH
|
||||
$ENV{GOPATH}
|
||||
$ENV{GOROOT}
|
||||
$ENV{GOROOT}/../bin
|
||||
$ENV{GO_COMPILER}
|
||||
/usr/bin
|
||||
/usr/local/bin
|
||||
)
|
||||
|
||||
if(CMAKE_Go_COMPILER_INIT)
|
||||
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler")
|
||||
else()
|
||||
find_program(CMAKE_Go_COMPILER
|
||||
NAMES go
|
||||
PATHS ${Go_BIN_PATH}
|
||||
)
|
||||
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION)
|
||||
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}")
|
||||
message("-- The Golang compiler identification is ${VERSION}")
|
||||
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}")
|
||||
endif()
|
||||
|
||||
endif()
|
||||
|
||||
mark_as_advanced(CMAKE_Go_COMPILER)
|
||||
|
||||
configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in
|
||||
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)
|
||||
|
||||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
|
@ -1,8 +0,0 @@
|
||||
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@")
|
||||
set(CMAKE_Go_COMPILER_LOADED 1)
|
||||
|
||||
set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go)
|
||||
set(CMAKE_Go_LINKER_PREFERENCE 40)
|
||||
set(CMAKE_Go_OUTPUT_EXTENSION .o)
|
||||
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1)
|
||||
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")
|
@ -1,7 +0,0 @@
|
||||
if(NOT CMAKE_Go_COMPILE_OBJECT)
|
||||
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ")
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_Go_LINK_EXECUTABLE)
|
||||
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ")
|
||||
endif()
|
@ -1 +0,0 @@
|
||||
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "")
|
@ -1,45 +0,0 @@
|
||||
# Setting Paddle Compile Flags
|
||||
include(CheckCXXCompilerFlag)
|
||||
include(CheckCCompilerFlag)
|
||||
include(CheckCXXSymbolExists)
|
||||
include(CheckTypeSize)
|
||||
|
||||
function(CheckCompilerCXX11Flag)
|
||||
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
|
||||
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
|
||||
endif()
|
||||
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
|
||||
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
|
||||
# https://gist.github.com/yamaya/2924292
|
||||
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X
|
||||
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1)
|
||||
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
|
||||
endif()
|
||||
else()
|
||||
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
|
||||
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
CheckCompilerCXX11Flag()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
|
||||
|
||||
# Common gpu architectures: Kepler, Maxwell
|
||||
foreach(capability 30 35 50)
|
||||
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
|
||||
endforeach()
|
||||
|
||||
if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0")
|
||||
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
|
||||
endif()
|
||||
|
||||
# Modern gpu architectures: Pascal
|
||||
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
|
||||
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
|
||||
endif()
|
||||
|
||||
set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})
|
@ -1,48 +0,0 @@
|
||||
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
|
||||
file(MAKE_DIRECTORY ${GOPATH})
|
||||
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle")
|
||||
file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH})
|
||||
|
||||
function(GO_LIBRARY NAME BUILD_TYPE)
|
||||
if(BUILD_TYPE STREQUAL "STATIC")
|
||||
set(BUILD_MODE -buildmode=c-archive)
|
||||
set(LIB_NAME "lib${NAME}.a")
|
||||
else()
|
||||
set(BUILD_MODE -buildmode=c-shared)
|
||||
if(APPLE)
|
||||
set(LIB_NAME "lib${NAME}.dylib")
|
||||
else()
|
||||
set(LIB_NAME "lib${NAME}.so")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
|
||||
file(RELATIVE_PATH rel ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
# find Paddle directory.
|
||||
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
|
||||
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
|
||||
get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY)
|
||||
|
||||
# automatically get all dependencies specified in the source code
|
||||
# for given target.
|
||||
add_custom_target(${NAME}_goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
|
||||
|
||||
# make a symlink that references Paddle inside $GOPATH, so go get
|
||||
# will use the local changes in Paddle rather than checkout Paddle
|
||||
# in github.
|
||||
add_custom_target(${NAME}_copyPaddle
|
||||
COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
|
||||
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
|
||||
add_dependencies(${NAME}_goGet ${NAME}_copyPaddle)
|
||||
|
||||
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
|
||||
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
|
||||
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
|
||||
${CMAKE_GO_FLAGS} ${GO_SOURCE}
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
|
||||
add_dependencies(${NAME} ${NAME}_goGet)
|
||||
|
||||
endfunction(GO_LIBRARY)
|
@ -1,45 +1,69 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/rpc"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/namsral/flag"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/PaddlePaddle/Paddle/go/master"
|
||||
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
port := flag.Int("port", 8080, "port of the master server.")
|
||||
|
||||
faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
|
||||
ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.")
|
||||
endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.")
|
||||
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
|
||||
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
|
||||
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
|
||||
flag.Parse()
|
||||
|
||||
if *faultTolerance {
|
||||
panic("fault tolernance not implemented.")
|
||||
if *endpoints == "" {
|
||||
log.Warningln("-endpoints not set, fault tolerance not be enabled.")
|
||||
}
|
||||
|
||||
var store master.Store
|
||||
if *endpoints != "" {
|
||||
eps := strings.Split(*endpoints, ",")
|
||||
ip, err := networkhelper.GetExternalIP()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", ip, *port)
|
||||
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
} else {
|
||||
store = &master.InMemStore{}
|
||||
}
|
||||
|
||||
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
|
||||
err := rpc.Register(s)
|
||||
err = rpc.Register(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rpc.HandleHTTP()
|
||||
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = http.Serve(l, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,172 @@
|
||||
package master
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/etcd/clientv3"
|
||||
"github.com/coreos/etcd/clientv3/concurrency"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DefaultLockPath is the default etcd master lock path.
const DefaultLockPath = "/master/lock"

// DefaultStatePath is the default etcd key for master state.
const DefaultStatePath = "/master/state"

// DefaultAddrPath is the default etcd key for master address.
const DefaultAddrPath = "/master/addr"
|
||||
|
||||
// EtcdClient is the etcd client that the master uses for fault
|
||||
// tolerance and service registry.
|
||||
type EtcdClient struct {
|
||||
lockPath string
|
||||
statePath string
|
||||
client *clientv3.Client
|
||||
lock *concurrency.Mutex
|
||||
}
|
||||
|
||||
// NewEtcdClient creates a new EtcdClient.
|
||||
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
|
||||
log.Debugf("Connecting to etcd at %v", endpoints)
|
||||
// TODO(helin): gracefully shutdown etcd store. Becuase etcd
|
||||
// store holds a etcd lock, even though the lock will expire
|
||||
// when the lease timeout, we need to implement graceful
|
||||
// shutdown to release the lock.
|
||||
cli, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: endpoints,
|
||||
DialTimeout: dialTimeout,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lock := concurrency.NewMutex(sess, lockPath)
|
||||
// It's fine for the lock to get stuck, in this case we have
|
||||
// multiple master servers running (only configured to have
|
||||
// one master running, but split-brain problem may cuase
|
||||
// multiple master servers running), and the cluster management
|
||||
// software will kill one of them.
|
||||
log.Debugf("Trying to acquire lock at %s.", lockPath)
|
||||
err = lock.Lock(context.TODO())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Successfully acquired lock at %s.", lockPath)
|
||||
|
||||
put := clientv3.OpPut(addrPath, string(addr))
|
||||
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Fatal("No longer owns the master lock. Exiting.")
|
||||
}
|
||||
|
||||
e := &EtcdClient{
|
||||
lockPath: lockPath,
|
||||
statePath: statePath,
|
||||
client: cli,
|
||||
lock: lock,
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// Save saves the state into the etcd.
|
||||
func (e *EtcdClient) Save(state []byte) error {
|
||||
ctx := context.TODO()
|
||||
put := clientv3.OpPut(e.statePath, string(state))
|
||||
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Errorln("No longer owns the lock, trying to lock again")
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
err := e.lock.Lock(ctx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
// We lost the master lock and can not acquire
|
||||
// it back, it means some other master is
|
||||
// already started. We don't want cluster
|
||||
// managment system to kill the master server
|
||||
// who is holding the lock and running
|
||||
// correctly. So the most feasible solution is
|
||||
// to kill current master server. The current
|
||||
// state is not saved, but the trainer's RPC
|
||||
// call will fail, so the trainer will retry.
|
||||
log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err)
|
||||
}
|
||||
log.Infof("Successfully acquired lock at %s.", e.lockPath)
|
||||
return e.Save(state)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Load loads the state from etcd.
|
||||
func (e *EtcdClient) Load() ([]byte, error) {
|
||||
ctx := context.TODO()
|
||||
get := clientv3.OpGet(e.statePath)
|
||||
|
||||
resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(get).Commit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !resp.Succeeded {
|
||||
log.Errorln("No longer owns the lock, trying to lock and load again.")
|
||||
err = e.lock.Lock(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return e.Load()
|
||||
}
|
||||
|
||||
kvs := resp.Responses[0].GetResponseRange().Kvs
|
||||
if len(kvs) == 0 {
|
||||
// No state exists
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
state := kvs[0].Value
|
||||
return state, nil
|
||||
}
|
||||
|
||||
// GetKey gets the value by the specify key.
|
||||
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
|
||||
resp, err := c.Get(ctx, key)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
kvs := resp.Kvs
|
||||
if len(kvs) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
v := kvs[0].Value
|
||||
return string(v), nil
|
||||
}
|
||||
|
||||
// WatchKey watches the specify key and send to valChan if there is some event.
|
||||
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
|
||||
rch := c.Watch(context.Background(), key)
|
||||
for wresp := range rch {
|
||||
for _, ev := range wresp.Events {
|
||||
// if received event is DELETE, the value will be an empty string
|
||||
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
|
||||
valChan <- string(ev.Kv.Value)
|
||||
}
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue