Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev_add_tensor_copy

cblas_new
fengjiayi 8 years ago
commit dfa4650920

3
.gitignore vendored

@ -22,3 +22,6 @@ cmake-build-*
# generated while compiling # generated while compiling
python/paddle/v2/framework/core.so python/paddle/v2/framework/core.so
CMakeFiles
cmake_install.cmake

@ -28,7 +28,9 @@ if(NOT CMAKE_CROSSCOMPILING)
endif(NOT CMAKE_CROSSCOMPILING) endif(NOT CMAKE_CROSSCOMPILING)
find_package(Git REQUIRED) find_package(Git REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
find_package(Boost QUIET) if(NOT ANDROID)
find_package(Boost QUIET)
endif()
include(simd) include(simd)
@ -140,6 +142,10 @@ endif(USE_NNPACK)
add_subdirectory(proto) add_subdirectory(proto)
# "add_subdirectory(go)" should be placed after the following loine,
# because it depends on paddle/optimizer.
add_subdirectory(paddle/optimizer)
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be # "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
# placed after this block, because they depends on it. # placed after this block, because they depends on it.
if(WITH_GOLANG) if(WITH_GOLANG)
@ -147,7 +153,9 @@ if(WITH_GOLANG)
endif(WITH_GOLANG) endif(WITH_GOLANG)
add_subdirectory(paddle) add_subdirectory(paddle)
add_subdirectory(python) if(WITH_PYTHON)
add_subdirectory(python)
endif()
if(WITH_DOC) if(WITH_DOC)
add_subdirectory(doc) add_subdirectory(doc)
endif() endif()

@ -106,6 +106,9 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
SET(CMAKE_SYSTEM_PROCESSOR armv7-a) SET(CMAKE_SYSTEM_PROCESSOR armv7-a)
ENDIF() ENDIF()
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
ENDIF()
SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-") SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
ENDIF() ENDIF()
@ -162,6 +165,10 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
ENDIF() ENDIF()
ENDIF() ENDIF()
IF(ANDROID_ABI STREQUAL "arm64-v8a")
LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
ENDIF()
STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}") STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}") STRING(REPLACE ";" " " ANDROID_LINKER_FLAGS "${ANDROID_LINKER_FLAGS}")

@ -52,6 +52,7 @@ ExternalProject_Add(
ADD_LIBRARY(glog STATIC IMPORTED GLOBAL) ADD_LIBRARY(glog STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES}) SET_PROPERTY(TARGET glog PROPERTY IMPORTED_LOCATION ${GLOG_LIBRARIES})
ADD_DEPENDENCIES(glog extern_glog) ADD_DEPENDENCIES(glog extern_glog gflags)
LINK_LIBRARIES(glog gflags)
LIST(APPEND external_project_dependencies glog) LIST(APPEND external_project_dependencies glog)

@ -32,7 +32,12 @@ IF(NOT ${CBLAS_FOUND})
# arm_soft_fp_abi branch of OpenBLAS to support softfp # arm_soft_fp_abi branch of OpenBLAS to support softfp
# https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi # https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5") SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0) IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
SET(TARGET "ARMV7")
ELSEIF(ANDROID_ABI STREQUAL "arm64-v8a")
SET(TARGET "ARMV8")
ENDIF()
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=${TARGET} ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(RPI) ELSEIF(RPI)
# use hardfp # use hardfp
SET(OPENBLAS_COMMIT "v0.2.19") SET(OPENBLAS_COMMIT "v0.2.19")

@ -90,11 +90,11 @@
# including binary directory for generated headers. # including binary directory for generated headers.
include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE AND NOT ANDROID)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl") set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt")
endif(NOT APPLE) endif(NOT APPLE AND NOT ANDROID)
function(merge_static_libs TARGET_NAME) function(merge_static_libs TARGET_NAME)
set(libs ${ARGN}) set(libs ${ARGN})
@ -301,7 +301,7 @@ function(go_library TARGET_NAME)
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${${TARGET_NAME}_LIB_PATH}" COMMAND rm "${${TARGET_NAME}_LIB_PATH}"
# Golang build source code # Golang build source code
@ -309,7 +309,7 @@ function(go_library TARGET_NAME)
-o "${${TARGET_NAME}_LIB_PATH}" -o "${${TARGET_NAME}_LIB_PATH}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}"
# must run under GOPATH # must run under GOPATH
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_dependencies(${TARGET_NAME} go_vendor) add_dependencies(${TARGET_NAME} go_vendor)
endfunction(go_library) endfunction(go_library)
@ -320,14 +320,11 @@ function(go_binary TARGET_NAME)
cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
# TODO: don't know what ${TARGET_NAME}_link does
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin) install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin)
endfunction(go_binary) endfunction(go_binary)
@ -335,15 +332,18 @@ endfunction(go_binary)
function(go_test TARGET_NAME) function(go_test TARGET_NAME)
set(options OPTIONAL) set(options OPTIONAL)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs DEPS)
cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_test_SRCS} ".${CMAKE_CURRENT_SOURCE_REL_DIR}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test) endfunction(go_test)
function(proto_library TARGET_NAME) function(proto_library TARGET_NAME)

@ -17,3 +17,7 @@ add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver) add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master) add_subdirectory(cmd/master)
add_subdirectory(master/c) add_subdirectory(master/c)
add_subdirectory(master)
add_subdirectory(pserver)
add_subdirectory(pserver/client)
add_subdirectory(utils/networkhelper)

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
go_binary(master SRC master.go DEPS paddle_go_optimizer) go_binary(master SRC master.go)

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -18,53 +19,47 @@ func main() {
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
level, err := log.ParseLevel(*logLevel) level, err := log.ParseLevel(*logLevel)
if err != nil { candy.Must(err)
panic(err)
}
log.SetLevel(level) log.SetLevel(level)
var idx int var idx int
var cp pserver.Checkpoint var cp pserver.Checkpoint
var e *pserver.EtcdClient var e *pserver.EtcdClient
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
timeout := time.Second * time.Duration((*etcdTimeout)) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register() idx, err = e.Register()
candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
if err != nil { if err != nil {
panic(err) log.Errorf("Fetch checkpoint failed, %s", err)
} }
} }
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil { candy.Must(err)
panic(err)
}
err = rpc.Register(s) err = rpc.Register(s)
if err != nil { candy.Must(err)
panic(err)
}
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { candy.Must(err)
panic(err)
}
log.Infof("start pserver at port %d", *port) log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil) err = http.Serve(l, nil)
candy.Must(err)
if err != nil {
panic(err)
}
} }

14
go/glide.lock generated

@ -1,8 +1,8 @@
hash: b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7 hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
updated: 2017-06-27T14:05:48.925262819+08:00 updated: 2017-07-11T10:04:40.786745417+08:00
imports: imports:
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: 61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7 version: cb2a496c4ddd1c87a9f280e116649b599999ec79
subpackages: subpackages:
- auth/authpb - auth/authpb
- clientv3 - clientv3
@ -22,7 +22,9 @@ imports:
- name: github.com/PaddlePaddle/recordio - name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129 version: edfb82af0739c84f241c87390ec5649c7b28c129
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: golang.org/x/net - name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2 version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages: subpackages:
@ -34,11 +36,11 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733 version: abf9c25f54453410d0c6668e519582a9e1115027
subpackages: subpackages:
- unix - unix
- name: golang.org/x/text - name: golang.org/x/text
version: 4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77 version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
subpackages: subpackages:
- secure/bidirule - secure/bidirule
- transform - transform

@ -10,3 +10,4 @@ import:
version: ^1.7.4-pre version: ^1.7.4-pre
- package: github.com/sirupsen/logrus - package: github.com/sirupsen/logrus
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy

@ -0,0 +1,3 @@
if(WITH_TESTING)
go_test(master_test)
endif()

@ -0,0 +1,3 @@
if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer)
endif()

@ -0,0 +1,3 @@
if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()

@ -0,0 +1 @@
libpaddle_go_optimizer.a

@ -1,5 +1,13 @@
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m) target_link_libraries(paddle_go_optimizer stdc++ m)
# Copy library to the required place.
# See: go/pserver/optimizer.go:
# // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
add_custom_command(TARGET paddle_go_optimizer POST_BUILD
COMMAND cp "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_go_optimizer.a" "${CMAKE_CURRENT_SOURCE_DIR}"
)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING) if(WITH_TESTING)
# FIXME: this test requires pserver which is not managed by the test # FIXME: this test requires pserver which is not managed by the test

@ -16,7 +16,7 @@ import (
const ( const (
// PsDesired is etcd path for store desired pserver count // PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired" PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr // PsPath is the base dir for pserver to store their addr
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information // PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/" PsCheckpoint = "/checkpoints/"
@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil return idx, nil
} }
// GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := e.etcdClient.Get(ctx, key)
cancel()
if err != nil {
return []byte{}, err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return []byte{}, nil
}
v := kvs[0].Value
return v, nil
}
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.etcdClient.Put(ctx, key, string(value)) _, err := e.etcdClient.Put(ctx, key, string(value))
cancel() cancel()
if err != nil { if err != nil {

@ -1,8 +1,7 @@
package pserver package pserver
// #cgo CFLAGS: -I ../../ // #cgo CFLAGS: -I ../../
// //FIXME: ldflags contain "build" path // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #cgo LDFLAGS: ${SRCDIR}/../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #include "paddle/optimizer/optimizer.h" // #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h> // #include <stdlib.h>
// #include <string.h> // #include <string.h>

@ -9,6 +9,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@ -21,14 +22,14 @@ import (
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
type ElementType int type ElementType int
// RPC error message.
const ( const (
// AlreadyInitialized is true if pserver is initialized AlreadyInitialized = "pserver already initialized"
AlreadyInitialized = "pserver already initialized" Uninitialized = "pserver not fully initialized"
// Uninitialized is true if pserver not fully initialized CheckpointMD5Failed = "checkpoint file MD5 validation failed"
Uninitialized = "pserver not fully initialized"
) )
// Supported element types // Supported element types.
const ( const (
Int32 ElementType = iota Int32 ElementType = iota
UInt32 UInt32
@ -51,21 +52,15 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format Config []byte // parameter configuration in Proto Buffer format
} }
// ParameterCheckpoint is Parameter and State checkpoint // checkpointMeta saves checkpoint metadata
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}
// checkpoint signature
type checkpointMeta struct { type checkpointMeta struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Md5sum string `json:"md5sum"` MD5 string `json:"md5"`
Timestamp string `json:"timestamp"` Timestamp int64 `json:"timestamp"`
} }
// Checkpoint is the pserver shard persist in file // Checkpoint is the pserver shard persist in file
type Checkpoint []ParameterCheckpoint type Checkpoint []parameterCheckpoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
type Gradient Parameter type Gradient Parameter
@ -81,12 +76,53 @@ type Service struct {
optMap map[string]*optimizer optMap map[string]*optimizer
} }
// parameterCheckpoint saves parameter checkpoint
type parameterCheckpoint struct {
ParameterWithConfig
State []byte
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
if err != nil {
return nil, err
}
var cpMeta checkpointMeta
if err = json.Unmarshal(v, &cpMeta); err != nil {
return nil, err
}
fn := filepath.Join(cpPath, cpMeta.UUID)
if _, err = os.Stat(fn); os.IsNotExist(err) {
return nil, err
}
content, err := ioutil.ReadFile(fn)
if err != nil {
return nil, err
}
h := md5.New()
md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 {
return nil, errors.New(CheckpointMD5Failed)
}
dec := gob.NewDecoder(bytes.NewReader(content))
cp := Checkpoint{}
if err = dec.Decode(cp); err != nil {
return nil, err
}
return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
checkpointInterval: time.Second * time.Duration(seconds), checkpointInterval: interval,
checkpointPath: path, checkpointPath: path,
client: client, client: client,
} }
@ -95,9 +131,11 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp
if cp != nil { if cp != nil {
for _, item := range cp { for _, item := range cp {
p := item.ParamConfig p := ParameterWithConfig{
st := item.State Param: item.Param,
s.optMap[p.Param.Name] = newOptimizer(p, st) Config: item.Config,
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
} }
} }
return s, nil return s, nil
@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
cp := make([]ParameterCheckpoint, 0, len(s.optMap)) cp := make([]parameterCheckpoint, len(s.optMap))
index := 0 index := 0
for name, opt := range s.optMap { for name, opt := range s.optMap {
var pc ParameterCheckpoint var pc parameterCheckpoint
pc.ParamConfig.Param.Name = name pc.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType pc.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights() pc.Param.Content = opt.GetWeights()
pc.State = opt.GetStates() pc.State = opt.GetStates()
cp[index] = pc cp[index] = pc
index++ index++
@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error {
cpMeta := checkpointMeta{} cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String() cpMeta.Timestamp = time.Now().UnixNano()
h := md5.New() h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, _ := json.Marshal(cpMeta) cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
if err != nil { if err != nil {
return err return err
} }
@ -219,7 +257,11 @@ func (s *Service) doCheckpoint() error {
log.Info("checkpoint does not exists.") log.Info("checkpoint does not exists.")
} else { } else {
err = os.Remove(cpMeta.UUID) err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) if err != nil {
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
} else {
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
} }
f, err := os.Create(cpMeta.UUID) f, err := os.Create(cpMeta.UUID)
defer f.Close() defer f.Close()

@ -0,0 +1,3 @@
if(WITH_TESTING)
go_test(network_helper_test)
endif()

@ -8,7 +8,6 @@ add_subdirectory(gserver)
add_subdirectory(pserver) add_subdirectory(pserver)
add_subdirectory(trainer) add_subdirectory(trainer)
add_subdirectory(scripts) add_subdirectory(scripts)
add_subdirectory(optimizer)
add_subdirectory(string) add_subdirectory(string)
if(Boost_FOUND) if(Boost_FOUND)

@ -11,7 +11,7 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type) proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc device_context)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include <type_traits>
#include "paddle/framework/attr_checker.h" #include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/op_proto.pb.h"
@ -81,8 +82,6 @@ class OpProtoAndCheckerMaker {
return op_checker_->AddAttrChecker<T>(name); return op_checker_->AddAttrChecker<T>(name);
} }
void AddType(const std::string& op_type) { proto_->set_type(op_type); }
void AddComment(const std::string& comment) { void AddComment(const std::string& comment) {
*(proto_->mutable_comment()) = comment; *(proto_->mutable_comment()) = comment;
} }
@ -101,8 +100,11 @@ class OpRegistry {
OpProto& op_proto = protos()[op_type]; OpProto& op_proto = protos()[op_type];
OpAttrChecker& op_checker = op_checkers()[op_type]; OpAttrChecker& op_checker = op_checkers()[op_type];
ProtoMakerType(&op_proto, &op_checker); ProtoMakerType(&op_proto, &op_checker);
PADDLE_ENFORCE(op_proto.IsInitialized(), *op_proto.mutable_type() = op_type;
"Fail to initialize %s's OpProto !", op_type); PADDLE_ENFORCE(
op_proto.IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_proto.InitializationErrorString());
} }
static OperatorBase* CreateOp(const OpDesc& op_desc) { static OperatorBase* CreateOp(const OpDesc& op_desc) {
@ -119,6 +121,7 @@ class OpRegistry {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr); op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
} }
op_checkers().at(op_type).Check(op->attrs_); op_checkers().at(op_type).Check(op->attrs_);
op->Init();
return op; return op;
} }
@ -142,18 +145,73 @@ class OpRegistry {
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
class OpRegisterHelper { class OpRegisterHelper {
public: public:
OpRegisterHelper(std::string op_type) { OpRegisterHelper(const char* op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type); OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
} }
}; };
#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \ #define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
class __op_class##Register { \ struct __test_global_namespace_##uniq_name##__ {}; \
private: \ static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \ __test_global_namespace_##uniq_name##__>::value, \
}; \ msg)
const OpRegisterHelper<__op_class, __op_maker_class> \
__op_class##Register::reg(#__op_type); #define REGISTER_OP(__op_type, __op_class, __op_maker_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \
"REGISTER_OP must be in global namespace"); \
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
__op_register_##__op_type##__(#__op_type); \
int __op_register_##__op_type##_handle__() { return 0; }
#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##GPU_OR_CPU##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; }
#define REGISTER_OP_GPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)
#define REGISTER_OP_CPU_KERNEL(type, KernelType) \
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)
#define USE_OP_WITHOUT_KERNEL(op_type) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_without_kernel_##op_type, \
"USE_OP_WITHOUT_KERNEL must be in global namespace"); \
extern int __op_register_##op_type##_handle__(); \
static int __use_op_ptr_##op_type##_without_kernel__ \
__attribute__((unused)) = __op_register_##op_type##_handle__()
#define USE_OP_KERNEL(op_type, DEVICE_TYPE) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
"USE_OP_KERNEL must be in global namespace"); \
extern int __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__(); \
static int __use_op_ptr_##op_type##_##DEVICE_TYPE##_kernel__ \
__attribute__((unused)) = \
__op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__()
#ifdef PADDLE_ONLY_CPU
#define USE_OP(op_type) \
USE_OP_WITHOUT_KERNEL(op_type); \
USE_OP_KERNEL(op_type, CPU);
#else
#define USE_OP(op_type) \
USE_OP_WITHOUT_KERNEL(op_type); \
USE_OP_KERNEL(op_type, CPU); \
USE_OP_KERNEL(op_type, GPU)
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -1,17 +1,13 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"
using namespace paddle::framework;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorWithKernel { class CosineOp : public OperatorBase {
public: public:
void Run(const OpRunContext* context) const override { void Run(const std::shared_ptr<Scope>& scope,
printf("%s\n", DebugString().c_str()); const platform::DeviceContext& dev_ctx) const override {}
} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
}; };
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@ -23,19 +19,17 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.LargerThan(0.0); .LargerThan(0.0);
AddType("cos");
AddComment("This is cos op"); AddComment("This is cos op");
} }
}; };
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
class MyTestOp : public OperatorWithKernel {
public: public:
void Run(const OpRunContext* ctx) const override {
printf("%s\n", DebugString().c_str());
printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
}
}; };
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@ -49,15 +43,17 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
}; };
AddAttr<int>("test_attr", "a simple test attribute") AddAttr<int>("test_attr", "a simple test attribute")
.AddCustomChecker(my_checker); .AddCustomChecker(my_checker);
AddType("my_test_op");
AddComment("This is my_test op"); AddComment("This is my_test op");
} }
}; };
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_OP(cos_sim, paddle::framework::CosineOp,
paddle::framework::CosineOpProtoAndCheckerMaker);
REGISTER_OP(my_test_op, paddle::framework::MyTestOp,
paddle::framework::MyTestOpProtoAndCheckerMaker);
TEST(OpRegistry, CreateOp) { TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc; paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim"); op_desc.set_type("cos_sim");
@ -72,9 +68,9 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale"); float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(scale_get, scale); ASSERT_EQ(scale_get, scale);
} }
@ -115,9 +111,9 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0); ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
} }
@ -169,14 +165,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(4); attr->set_i(4);
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<paddle::framework::Scope>();
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr"); int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
} }
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
return ss.str(); return ss.str();
} }
const Variable* OpRunContext::Input(int index) const {
return scope_->GetVariable(op_->inputs_[index]);
}
Variable* OpRunContext::Output(int index) const {
return scope_->GetVariable(op_->outputs_[index]);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save