Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into cpu_mem

cblas_new
liaogang 8 years ago
commit 68ab1ef4db

@ -1,4 +1,3 @@
group: deprecated-2017Q2
language: cpp
cache:
directories:

@ -93,6 +93,7 @@ include(external/openblas) # download, build, install openblas
include(external/swig) # download, build, install swig
include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any
include(external/eigen) # download eigen3
include(generic) # simplify cmake module
include(package) # set paddle packages

@ -0,0 +1,20 @@
INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/eigen3)
ExternalProject_Add(
eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
URL_MD5 "1a47e78efe365a97de0c022d127607c3"
PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
LIST(APPEND external_project_dependencies eigen3)

@ -77,6 +77,15 @@
# /cmake/external/*.cmake:
#
# cc_test(example_test SRCS example_test.cc DEPS example glog gflags)
#
# To build a go static library using Golang, use the go_ prefixed version:
#
# go_library(example STATIC)
#
# To build a go shared library using Golang, use the go_ prefixed version:
#
# go_library(example SHARED)
#
if(NOT APPLE)
find_package(Threads REQUIRED)
@ -246,42 +255,53 @@ endfunction(nv_test)
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle/Paddle")
# Because api.go defines a GO wrapper to ops and tensor, it depends on
# both. This implies that if any of tensor.{h,cc}, ops.{h,cu}, or
# api.go is changed, api need to be re-built.
# go_library(api
# SRCS
# api.go
# DEPS
# tensor # Because ops depend on tensor, this line is optional.
# ops)
function(go_library TARGET_NAME)
set(options OPTIONAL)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
set(multiValueArgs DEPS)
cmake_parse_arguments(go_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if (${go_library_OPTIONAL} STREQUAL "SHARED")
if (go_library_SHARED OR go_library_shared)
set(BUILD_MODE "-buildmode=c-shared")
if(APPLE)
set(LIB_NAME "lib${TARGET_NAME}.dylib")
else()
set(LIB_NAME "lib${TARGET_NAME}.so")
endif()
set(LIB_NAME "${CMAKE_SHARED_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX}")
else()
set(BUILD_MODE "-buildmode=c-archive")
set(LIB_NAME "lib${TARGET_NAME}.a")
set(LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}")
endif()
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
# Add dummy code to support `make target_name` under Terminal Command
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
if (go_library_SHARED OR go_library_shared)
add_library(${TARGET_NAME} SHARED ${dummyfile})
else()
add_library(${TARGET_NAME} STATIC ${dummyfile})
endif()
if(go_library_DEPS)
add_dependencies(${TARGET_NAME} ${go_library_DEPS})
endif(go_library_DEPS)
# we need to symlink Paddle directory into GOPATH. If we
# don't do it and we have code that depends on Paddle, go
# get ./... will download a new Paddle repo from Github,
# without the changes in our current Paddle repo that we
# want to build.
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
# Symlink Paddle directory into GOPATH
COMMAND mkdir -p ${PADDLE_IN_GOPATH}
COMMAND rm -rf ${PADDLE_IN_GOPATH}
COMMAND ln -sf ${CMAKE_SOURCE_DIR} ${PADDLE_IN_GOPATH}
# Automatically get all dependencies specified in the source code
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ./...
# Golang build source code
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${go_library_SRCS}
${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME}_lib ALL DEPENDS ${TARGET_NAME}_timestamp ${go_library_DEPS})
add_library(${TARGET_NAME} STATIC IMPORTED)
set_property(TARGET ${TARGET_NAME} PROPERTY
IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}")
add_dependencies(${TARGET_NAME} ${TARGET_NAME}_lib)
endfunction(go_library)
function(go_binary TARGET_NAME)
@ -311,10 +331,3 @@ function(go_test TARGET_NAME)
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test)
# go_extern will download extern go project.
# go_extern(target_name extern_source)
# go_extern(go_redis github.com/hoisie/redis)
function(go_extern TARGET_NAME)
add_custom_target(${TARGET_NAME} env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get ${ARGN})
endfunction(go_extern)

@ -33,6 +33,7 @@ ELSE(WIN32)
SET(CMAKE_OSX_DEPLOYMENT_TARGET ${MACOS_VERSION} CACHE STRING
"Minimum OS X version to target for deployment (at runtime); newer APIs weak linked. Set to empty string for default value.")
ENDIF()
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
ELSE(APPLE)
IF(EXISTS "/etc/issue")

@ -0,0 +1,124 @@
# Design of Scope in Paddle
## Overview
Scope is an important concept in programming languages, which defines a program region that a set of bindings between names and entities applies. In a specific scope, a valid name is uniquely associated with an entity, such as a variable. And in another scope, this name may refer to other entity or nothing at all. It clearly restricts the visibility and validity of names in a program. Hence **Scope** is introduced to PaddlePaddle to manage variables in context. But different from the original abstract concept, Scope now becomes an object with two important attributes:
- Scope is an association of a name to variable.
- Variables in a parent scope can be retrieved from local scope.
A detailed explanation of these two attributes goes as following.
## Scope is an association of a name to variable.
Scope is an association of a name to variable. All variables belong to `Scope`. You need to specify a scope to run a Net, i.e., `net.Run(&scope)`. One net can run in different scopes and update different variable in the scope.
1. Scope only contains a map of a name to variable.
All parameters, data, states in a Net should be variables and stored inside a scope. Each op should get inputs and outputs to do computation from a scope, such as data buffer, state(momentum) etc.
1. Variable can only be created by Scope and a variable can only be got from Scope. User cannot create or get a variable outside a scope. This is a constraints of our framework, and will keep our framework simple and clear.
1. Scope only contains methods that are used to Create and Get Variables. Scope do not contain Operators and have no information to run them.
`Net` is designed to drive the computation and Scope only contains a map of variables. There is no computation logic inside a `Scope`. Scope just handles the lifetime management of variables.
- `Create` is used to create a Variable by its name and add the mapping relation.
- `Get` is used to find a Variable by name.
1. Every variable only belongs to one certain Scope.
Variable can not belong to many scopes. If you want to use variables from parent scope, you can use `parent scope`.
1. Scope should destruct all Variables inside it when itself is destructed. User can never store `Variable` pointer somewhere else.
Because Variable can only be got from Scope. When destroying Scope, we also need to destroy all the Variables in it. If user store `Variable` pointer to private data member or some global variable, the pointer will be a invalid pointer when associated `Scope` is destroyed.
```cpp
class Scope {
public:
Variable* CreateVariable(const std::string& name);
const Variable* GetVariable(const std::string& name) const;
private:
std::unordered_map<std::string, std::unique_ptr<Vairable>> vars_;
};
```
## Parent scope and local scope
Just like [scope](https://en.wikipedia.org/wiki/Scope_(computer_science)) in programming languages, `Scope` in the neural network can also be a local scope. There are two attributes about local scope.
1. We can create local variables in a local scope. When that local scope are destroyed, all local variables should also be destroyed.
2. Variables in a parent scope can be retrieved from local scopes of that parent scope, i.e., when user get a variable from a scope, it will try to search this variable in current scope. If there is no such variable in the local scope, `scope` will keep searching from its parent, until the variable is found or there is no parent.
```cpp
class Scope {
public:
Scope(const std::shared_ptr<Scope>& scope): parent_(scope) {}
Variable* GetVariable(const std::string& name) const {
Variable* var = GetVarLocally(name);
if (var != nullptr) {
return var;
} else if (parent_ != nullptr) {
return parent_->GetVariable(name);
} else {
return nullptr;
}
}
private:
std::shared_ptr<Scope> parent_ {nullptr};
};
```
In `Scope` class, there is a private data member called `parent_`. `parent_` is a smart pointer to its parent scope. When user `Get` a variable by its `name`, the `name` will be searched inside the current scope. If the variable cannot be found locally and parent scope is not a `nullptr`, the variable will be searched inside that parent scope. `parent_` pointer's default value is `nullptr`. It means that the scope is a global scope when `parent_` is nullptr.
A local scope is very useful when we implement Recurrent Neural Network. Each timestep of an RNN should be a `Net`. Each `Net` of timestep (`StepNet` for short) should use an independent local scope. Just like variables in a while loop is inside a local scope in programming languages. By using a single `StepNet` and changing local scope, we can implement an RNN easily.
# Interface Design
```cpp
class Variable {
private:
Variable() = default;
friend class Scope;
};
class Scope {
private:
Scope(const std::shared_ptr<Scope>& parent = nullptr);
public:
static std::shared_ptr<Scope> Create(const std::shared_ptr<Scope>& parent = nullptr);
// return nullptr if not found.
Variable* GetVariable(const std::string& name) const;
// return Error if already contains same name variable.
Error CreateVariable(const std::string& name);
private:
std::shared_ptr<Scope> parent_;
std::unordered_map<std::string, std::unique_ptr<Variable>> vars_;
};
```
## Only scope can create a variable
To ensure `only scope can create a variable`, we should mark `Variable`'s constructor as a private member function, and Scope is a friend class of Variable. And then only `CreateVariable` can construct `Variable`.
## When scope destroyed, all variables inside this scope should be destroyed together
The scope hold unique pointers for all variables. User can `GetVariable` from scope, but he should not hold this pointer as a member variable. Because when scope is destroyed, all variables inside this scope will be destroyed together.
## Sharing a parent scope
Local scope contains a `parent_` pointer. It is a linked-list for scopes. Using a `shared_ptr` because when a local scope is using, its parents cannot be destroyed.
Also, as the parent scope is a `shared_ptr`, we can only `Create()` a scope shared pointer. We cannot construct a scope variable, because it cannot be passed to other scope as `parent` pointer.
## Orthogonal interface
`GetVariable` will return `nullptr` when `name` is not found. It can be used as `Contains` method. `CreateVariable` will return a `Error` when there is a name conflict locally. Combine `GetVariable` and `CreateVariable`, we can implement `CreateOrGetVariable` easily.

@ -111,7 +111,7 @@ PaddlePaddle支持不同类型的输入数据主要包括四种类型
# define training dataset reader
def train_reader():
train_x = np.array([[1, 1], [1, 2], [3, 4], [5, 2]])
train_y = np.array([-2, -3, -7, -7])
train_y = np.array([[-2], [-3], [-7], [-7]])
def reader():
for i in xrange(train_y.shape[0]):
yield train_x[i], train_y[i]

@ -1,44 +0,0 @@
if(NOT CMAKE_Go_COMPILER)
if(NOT $ENV{GO_COMPILER} STREQUAL "")
get_filename_component(CMAKE_Go_COMPILER_INIT $ENV{GO_COMPILER} PROGRAM PROGRAM_ARGS CMAKE_Go_FLAGS_ENV_INIT)
if(CMAKE_Go_FLAGS_ENV_INIT)
set(CMAKE_Go_COMPILER_ARG1 "${CMAKE_Go_FLAGS_ENV_INIT}" CACHE STRING "First argument to Go compiler")
endif()
if(NOT EXISTS ${CMAKE_Go_COMPILER_INIT})
message(SEND_ERROR "Could not find compiler set in environment variable GO_COMPILER:\n$ENV{GO_COMPILER}.")
endif()
endif()
set(Go_BIN_PATH
$ENV{GOPATH}
$ENV{GOROOT}
$ENV{GOROOT}/../bin
$ENV{GO_COMPILER}
/usr/bin
/usr/local/bin
)
if(CMAKE_Go_COMPILER_INIT)
set(CMAKE_Go_COMPILER ${CMAKE_Go_COMPILER_INIT} CACHE PATH "Go Compiler")
else()
find_program(CMAKE_Go_COMPILER
NAMES go
PATHS ${Go_BIN_PATH}
)
EXEC_PROGRAM(${CMAKE_Go_COMPILER} ARGS version OUTPUT_VARIABLE GOLANG_VERSION)
STRING(REGEX MATCH "go[0-9]+.[0-9]+.[0-9]+[ /A-Za-z0-9]*" VERSION "${GOLANG_VERSION}")
message("-- The Golang compiler identification is ${VERSION}")
message("-- Check for working Golang compiler: ${CMAKE_Go_COMPILER}")
endif()
endif()
mark_as_advanced(CMAKE_Go_COMPILER)
configure_file(${CMAKE_MODULE_PATH}/CMakeGoCompiler.cmake.in
${CMAKE_PLATFORM_INFO_DIR}/CMakeGoCompiler.cmake @ONLY)
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")

@ -1,8 +0,0 @@
set(CMAKE_Go_COMPILER "@CMAKE_Go_COMPILER@")
set(CMAKE_Go_COMPILER_LOADED 1)
set(CMAKE_Go_SOURCE_FILE_EXTENSIONS go)
set(CMAKE_Go_LINKER_PREFERENCE 40)
set(CMAKE_Go_OUTPUT_EXTENSION .o)
set(CMAKE_Go_OUTPUT_EXTENSION_REPLACE 1)
set(CMAKE_Go_COMPILER_ENV_VAR "GO_COMPILER")

@ -1,7 +0,0 @@
if(NOT CMAKE_Go_COMPILE_OBJECT)
set(CMAKE_Go_COMPILE_OBJECT "go tool compile -l -N -o <OBJECT> <SOURCE> ")
endif()
if(NOT CMAKE_Go_LINK_EXECUTABLE)
set(CMAKE_Go_LINK_EXECUTABLE "go tool link -o <TARGET> <OBJECTS> ")
endif()

@ -1 +0,0 @@
set(CMAKE_Go_COMPILER_WORKS 1 CACHE INTERNAL "")

@ -1,45 +0,0 @@
# Setting Paddle Compile Flags
include(CheckCXXCompilerFlag)
include(CheckCCompilerFlag)
include(CheckCXXSymbolExists)
include(CheckTypeSize)
function(CheckCompilerCXX11Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
# https://gist.github.com/yamaya/2924292
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1)
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
endif()
else()
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
endif()
endif()
endif()
endfunction()
CheckCompilerCXX11Flag()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
# Common gpu architectures: Kepler, Maxwell
foreach(capability 30 35 50)
list(APPEND __arch_flags " -gencode arch=compute_${capability},code=sm_${capability}")
endforeach()
if (CUDA_VERSION VERSION_GREATER "7.0" OR CUDA_VERSION VERSION_EQUAL "7.0")
list(APPEND __arch_flags " -gencode arch=compute_52,code=sm_52")
endif()
# Modern gpu architectures: Pascal
if (CUDA_VERSION VERSION_GREATER "8.0" OR CUDA_VERSION VERSION_EQUAL "8.0")
list(APPEND __arch_flags " -gencode arch=compute_60,code=sm_60")
endif()
set(CUDA_NVCC_FLAGS ${__arch_flags} ${CUDA_NVCC_FLAGS})

@ -1,48 +0,0 @@
set(GOPATH "${CMAKE_CURRENT_BINARY_DIR}/go")
file(MAKE_DIRECTORY ${GOPATH})
set(PADDLE_IN_GOPATH "${GOPATH}/src/github.com/PaddlePaddle")
file(MAKE_DIRECTORY ${PADDLE_IN_GOPATH})
function(GO_LIBRARY NAME BUILD_TYPE)
if(BUILD_TYPE STREQUAL "STATIC")
set(BUILD_MODE -buildmode=c-archive)
set(LIB_NAME "lib${NAME}.a")
else()
set(BUILD_MODE -buildmode=c-shared)
if(APPLE)
set(LIB_NAME "lib${NAME}.dylib")
else()
set(LIB_NAME "lib${NAME}.so")
endif()
endif()
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
file(RELATIVE_PATH rel ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR})
# find Paddle directory.
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
get_filename_component(PADDLE_DIR ${PARENT_DIR} DIRECTORY)
# automatically get all dependencies specified in the source code
# for given target.
add_custom_target(${NAME}_goGet env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} get -d ${rel}/...)
# make a symlink that references Paddle inside $GOPATH, so go get
# will use the local changes in Paddle rather than checkout Paddle
# in github.
add_custom_target(${NAME}_copyPaddle
COMMAND rm -rf ${PADDLE_IN_GOPATH}/Paddle
COMMAND ln -sf ${PADDLE_DIR} ${PADDLE_IN_GOPATH}/Paddle)
add_dependencies(${NAME}_goGet ${NAME}_copyPaddle)
add_custom_command(OUTPUT ${OUTPUT_DIR}/.timestamp
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} ${NAME}_goGet)
endfunction(GO_LIBRARY)

@ -30,7 +30,13 @@ func main() {
log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout)
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err := e.Register()
if err != nil {
panic(err)
}
s, err := pserver.NewService(idx)
if err != nil {
panic(err)
}

@ -13,10 +13,13 @@ typedef int paddle_master_client;
import "C"
import (
"strings"
"sync"
"time"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client {
return h
}
type addresser string
func (a addresser) Address() string {
return string(a)
//export paddle_new_etcd_master_client
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints)
cli, err := clientv3.New(clientv3.Config{
Endpoints: strings.Split(p, ","),
DialTimeout: time.Second * time.Duration(timeout),
})
if err != nil {
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c)
}
//export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
c := master.NewClient(addresser(a), bufSize)
ch := make(chan string, 1)
ch <- a
c := master.NewClient(ch, bufSize)
return add(c)
}

@ -2,18 +2,12 @@ package master
import (
"os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
log "github.com/sirupsen/logrus"
)
// Addresser provide the address of the master server.
type Addresser interface {
Address() string
}
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
@ -24,11 +18,11 @@ type Client struct {
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func NewClient(addr Addresser, bufSize int) *Client {
func NewClient(addrCh <-chan string, bufSize int) *Client {
c := &Client{}
c.conn = connection.New()
c.ch = make(chan []byte, bufSize)
go c.monitorMaster(addr)
go c.monitorMaster(addrCh)
go c.getRecords()
return c
}
@ -72,12 +66,10 @@ func (c *Client) getRecords() {
}
}
func (c *Client) monitorMaster(addr Addresser) {
func (c *Client) monitorMaster(addrCh <-chan string) {
lastMaster := ""
monitor := func() {
// get the lastest address of the master server,
for curMaster := range addrCh {
// connect to the new address once address changed.
curMaster := addr.Address()
if curMaster != lastMaster {
if curMaster == "" {
err := c.conn.Close()
@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) {
// to retry next time.
curMaster = lastMaster
}
}
}
lastMaster = curMaster
}
monitor()
ticker := time.NewTicker(10 * time.Second)
for _ = range ticker.C {
monitor()
}
}
// SetDataset set dataset for the master server to dispatch.

@ -26,12 +26,6 @@ func init() {
log.SetLevel(log.ErrorLevel)
}
type TestAddresser string
func (a TestAddresser) Address() string {
return string(a)
}
func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0"
@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) {
if err != nil {
panic(err)
}
go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) {
// Manually intialize client to avoid calling c.getRecords()
c := &Client{}
c.conn = connection.New()
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p)))
addr := fmt.Sprintf(":%d", p)
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
c.SetDataset([]string{path})
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {

@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) {
path = "/tmp/master_client_TestFull"
total = 50
)
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) {
if err != nil {
panic(err)
}
go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
if err != nil {
@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) {
}
w.Close()
f.Close()
c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10)
curAddr := make(chan string, 1)
curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {

@ -18,8 +18,8 @@ const (
DefaultAddrPath = "/master/addr"
)
// EtcdClient is the etcd client that master uses for fault tolerance
// and service registry.
// EtcdClient is the etcd client that the master uses for fault
// tolerance and service registry.
type EtcdClient struct {
lockPath string
statePath string
@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) {
state := kvs[0].Value
return state, nil
}
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
return "", err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return "", nil
}
v := kvs[0].Value
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
valChan <- string(ev.Kv.Value)
}
}
}

@ -1,14 +1,3 @@
cmake_minimum_required(VERSION 3.0)
get_filename_component(PARENT_DIR ${CMAKE_CURRENT_SOURCE_DIR} DIRECTORY)
get_filename_component(PARENT_DIR ${PARENT_DIR} DIRECTORY)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${PARENT_DIR}/cmake")
project(cxx_go C Go)
include(golang)
include(flags)
go_library(paddle_pserver_cclient STATIC)
add_subdirectory(test)

@ -1,22 +1,3 @@
cmake_minimum_required(VERSION 3.0)
add_executable(main main.c)
add_dependencies(main paddle_pserver_cclient)
add_executable(test_cclient test_cclient.c)
add_dependencies(test_cclient paddle_pserver_cclient)
if(APPLE)
set(CMAKE_EXE_LINKER_FLAGS "-framework CoreFoundation -framework Security")
else()
set(CMAKE_EXE_LINKER_FLAGS "-pthread")
endif()
if(PROJ_ROOT)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/..)
target_link_libraries(main ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread)
target_link_libraries(test_cclient ${CMAKE_CURRENT_BINARY_DIR}/../libpaddle_pserver_cclient.a pthread)
else(PROJ_ROOT)
include_directories(${CMAKE_BINARY_DIR})
target_link_libraries(main ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
target_link_libraries(test_cclient ${CMAKE_BINARY_DIR}/libpaddle_pserver_cclient.a pthread)
endif(PROJ_ROOT)
cc_library(main SRCS main.c DEPS paddle_pserver_cclient)
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient)

@ -1,6 +1,7 @@
package pserver
import (
"errors"
"hash/fnv"
"sort"
"time"
@ -123,6 +124,9 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []Gradient) error {
if len(grads) == 0 {
return errors.New("no gradient received")
}
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g Gradient) {

@ -7,7 +7,6 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
@ -31,7 +30,7 @@ func init() {
port[i] = p
go func(l net.Listener) {
s, err := pserver.NewService("", time.Second*5)
s, err := pserver.NewService(0)
if err != nil {
panic(err)
}

@ -0,0 +1,181 @@
package pserver
import (
"context"
"errors"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
numPservers int
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
}
// NewEtcdClient creates an EtcdClient
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
return &EtcdClient{
etcdTimeout: timeout,
numPservers: numPservers,
etcdEndpoints: endpoints,
}
}
// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) {
var err error
e.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return 0, err
}
// initialize connection to etcd.
ep := strings.Split(e.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: e.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(e.etcdTimeout)
continue
}
e.etcdClient = cli
log.Debugf("inited client to %s", e.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPsercers(ctx, e.numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := e.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(e.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(e.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
var pserverIdx int
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
var err error
pserverIdx, err = e.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(e.etcdTimeout)
continue
}
break
}
return pserverIdx, nil
}
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
var idx int
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < e.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := e.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
idx = i
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
if err != nil {
return 0, err
}
return idx, nil
}

@ -1,18 +1,9 @@
package pserver
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
// ElementType is the type of elements of a Parameter.
@ -55,160 +46,25 @@ type Gradient Parameter
// Service is the RPC service for pserver.
type Service struct {
initialized chan struct{}
idx int
mu sync.Mutex
opt *optimizer
paramMap map[string]Parameter
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)}
func NewService(idx int) (*Service, error) {
s := &Service{
idx: idx,
opt: newOptimizer(sgd, 0.005),
}
s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{})
s.etcdEndpoints = endpoints
s.etcdTimeout = timeout
var err error
s.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return nil, err
}
if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.initDesiredPsercers(ctx, numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(s.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(s.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
} // if endpoints != ""
// Bypass etcd registration if no endpoints specified
return s, nil
}
func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < s.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := s.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
// Eat the keep alive message so etcd
// will not expire the lease.
go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
ka := <-ch
log.Debugf("keepalive: %d\n", ka.TTL)
}(ch)
log.Debug("register finished")
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select {

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save