Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into crop_layer

cblas_new
wanghaoshuang 8 years ago
commit de5ded6bbd

@ -21,3 +21,10 @@
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29 sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
hooks: hooks:
- id: clang-formater - id: clang-formater
- repo: https://github.com/dnephin/pre-commit-golang
sha: e4693a4c282b4fc878eda172a929f7a6508e7d16
hooks:
- id: go-fmt
files: (.*\.go)
- id: go-lint
files: (.*\.go)

@ -37,12 +37,13 @@ before_install:
# protobuf version. # protobuf version.
- pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker - pip install numpy wheel 'protobuf==3.1' sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit requests==2.9.2 LinkChecker
- pip install rarfile - pip install rarfile
- curl https://glide.sh/get | bash
- eval "$(GIMME_GO_VERSION=1.8.3 gimme)" - eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
- | - |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; } function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script: script:
- | - |
export WITH_GOLANG=ON && timeout 2580 paddle/scripts/travis/${JOB}.sh # 43min timeout timeout 2580 paddle/scripts/travis/${JOB}.sh # 43min timeout
RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else false; fi; RESULT=$?; if [ $RESULT -eq 0 ] || [ $RESULT -eq 142 ]; then true; else false; fi;
notifications: notifications:
email: email:

@ -16,6 +16,7 @@ cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR}) set(PROJ_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
set(PROJ_BINARY_ROOT ${CMAKE_CURRENT_BINARY_DIR})
include(system) include(system)

@ -38,12 +38,14 @@ ExternalProject_Add(
CMAKE_ARGS -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} CMAKE_ARGS -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
CMAKE_ARGS -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} CMAKE_ARGS -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR} CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${GLOG_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_INSTALL_LIBDIR=${GLOG_INSTALL_DIR}/lib
CMAKE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE=ON CMAKE_ARGS -DCMAKE_POSITION_INDEPENDENT_CODE=ON
CMAKE_ARGS -DWITH_GFLAGS=ON CMAKE_ARGS -DWITH_GFLAGS=ON
CMAKE_ARGS -Dgflags_DIR=${GFLAGS_INSTALL_DIR}/lib/cmake/gflags CMAKE_ARGS -Dgflags_DIR=${GFLAGS_INSTALL_DIR}/lib/cmake/gflags
CMAKE_ARGS -DBUILD_TESTING=OFF CMAKE_ARGS -DBUILD_TESTING=OFF
CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOG_INSTALL_DIR} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GLOG_INSTALL_DIR}
-DCMAKE_INSTALL_LIBDIR:PATH=${GLOG_INSTALL_DIR}/lib
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_BUILD_TYPE:STRING=Release
) )

@ -17,6 +17,65 @@ INCLUDE(ExternalProject)
FIND_PACKAGE(Protobuf QUIET) FIND_PACKAGE(Protobuf QUIET)
SET(PROTOBUF_FOUND "OFF") SET(PROTOBUF_FOUND "OFF")
if(NOT COMMAND protobuf_generate_python) # before cmake 3.4, protobuf_genrerate_python is not defined.
function(protobuf_generate_python SRCS)
# shameless copy from https://github.com/Kitware/CMake/blob/master/Modules/FindProtobuf.cmake
if(NOT ARGN)
message(SEND_ERROR "Error: PROTOBUF_GENERATE_PYTHON() called without any proto files")
return()
endif()
if(PROTOBUF_GENERATE_CPP_APPEND_PATH)
# Create an include path for each file specified
foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(ABS_PATH ${ABS_FIL} PATH)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
else()
set(_protobuf_include_path -I ${CMAKE_CURRENT_SOURCE_DIR})
endif()
if(DEFINED PROTOBUF_IMPORT_DIRS AND NOT DEFINED Protobuf_IMPORT_DIRS)
set(Protobuf_IMPORT_DIRS "${PROTOBUF_IMPORT_DIRS}")
endif()
if(DEFINED Protobuf_IMPORT_DIRS)
foreach(DIR ${Protobuf_IMPORT_DIRS})
get_filename_component(ABS_PATH ${DIR} ABSOLUTE)
list(FIND _protobuf_include_path ${ABS_PATH} _contains_already)
if(${_contains_already} EQUAL -1)
list(APPEND _protobuf_include_path -I ${ABS_PATH})
endif()
endforeach()
endif()
set(${SRCS})
foreach(FIL ${ARGN})
get_filename_component(ABS_FIL ${FIL} ABSOLUTE)
get_filename_component(FIL_WE ${FIL} NAME_WE)
if(NOT PROTOBUF_GENERATE_CPP_APPEND_PATH)
get_filename_component(FIL_DIR ${FIL} DIRECTORY)
if(FIL_DIR)
set(FIL_WE "${FIL_DIR}/${FIL_WE}")
endif()
endif()
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py")
add_custom_command(
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}_pb2.py"
COMMAND ${Protobuf_PROTOC_EXECUTABLE} --python_out ${CMAKE_CURRENT_BINARY_DIR} ${_protobuf_include_path} ${ABS_FIL}
DEPENDS ${ABS_FIL} ${Protobuf_PROTOC_EXECUTABLE}
COMMENT "Running Python protocol buffer compiler on ${FIL}"
VERBATIM )
endforeach()
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
endfunction()
endif()
# Print and set the protobuf library information, # Print and set the protobuf library information,
# finish this cmake process and exit from this file. # finish this cmake process and exit from this file.

@ -88,7 +88,7 @@
# #
# including binary directory for generated headers. # including binary directory for generated headers.
include_directories(${CMAKE_BINARY_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
@ -99,25 +99,44 @@ function(merge_static_libs TARGET_NAME)
set(libs ${ARGN}) set(libs ${ARGN})
list(REMOVE_DUPLICATES libs) list(REMOVE_DUPLICATES libs)
# First get the file names of the libraries to be merged # Get all propagation dependencies from the merged libraries
foreach(lib ${libs}) foreach(lib ${libs})
set(libfiles ${libfiles} $<TARGET_FILE:${lib}>) list(APPEND libs_deps ${${lib}_LIB_DEPENDS})
endforeach() endforeach()
if(APPLE) # Use OSX's libtool to merge archives if(APPLE) # Use OSX's libtool to merge archives
# To produce a library we need at least one source file.
# It is created by add_custom_command below and will helps
# also help to track dependencies.
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}_dummy.c)
# Make the generated dummy source file depended on all static input
# libs. If input lib changes,the source file is touched
# which causes the desired effect (relink).
add_custom_command(OUTPUT ${dummyfile}
COMMAND ${CMAKE_COMMAND} -E touch ${dummyfile}
DEPENDS ${libs})
# Generate dummy staic lib
file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") file(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";")
add_library(${TARGET_NAME} STATIC ${dummyfile}) add_library(${TARGET_NAME} STATIC ${dummyfile})
target_link_libraries(${TARGET_NAME} ${libs_deps})
foreach(lib ${libs})
# Get the file names of the libraries to be merged
set(libfiles ${libfiles} $<TARGET_FILE:${lib}>)
endforeach()
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a"
COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles}) COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles})
else() # general UNIX: use "ar" to extract objects and re-add to a common lib else() # general UNIX: use "ar" to extract objects and re-add to a common lib
foreach(lib ${libs}) foreach(lib ${libs})
set(objlistfile ${lib}.objlist) # list of objects in the input library set(objlistfile ${lib}.objlist) # list of objects in the input library
set(objdir ${lib}.objdir) set(objdir ${lib}.objdir)
add_custom_command(OUTPUT ${objdir} add_custom_command(OUTPUT ${objdir}
COMMAND ${CMAKE_COMMAND} -E make_directory ${objdir}) COMMAND ${CMAKE_COMMAND} -E make_directory ${objdir}
DEPENDS ${lib})
add_custom_command(OUTPUT ${objlistfile} add_custom_command(OUTPUT ${objlistfile}
COMMAND ${CMAKE_AR} -x "$<TARGET_FILE:${lib}>" COMMAND ${CMAKE_AR} -x "$<TARGET_FILE:${lib}>"
@ -134,18 +153,18 @@ function(merge_static_libs TARGET_NAME)
list(APPEND mergebases "${mergebase}") list(APPEND mergebases "${mergebase}")
endforeach() endforeach()
# We need a target for the output merged library
add_library(${TARGET_NAME} STATIC ${mergebases}) add_library(${TARGET_NAME} STATIC ${mergebases})
target_link_libraries(${TARGET_NAME} ${libs_deps})
# Get the file name of the generated library
set(outlibfile "$<TARGET_FILE:${TARGET_NAME}>") set(outlibfile "$<TARGET_FILE:${TARGET_NAME}>")
foreach(lib ${libs}) foreach(lib ${libs})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND ${CMAKE_AR} ru ${outlibfile} @"../${lib}.objlist" COMMAND ${CMAKE_AR} cr ${outlibfile} *.o
WORKING_DIRECTORY ${lib}.objdir) COMMAND ${CMAKE_RANLIB} ${outlibfile}
WORKING_DIRECTORY ${lib}.objdir)
endforeach() endforeach()
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND ${CMAKE_RANLIB} ${outlibfile})
endif() endif()
endfunction(merge_static_libs) endfunction(merge_static_libs)
@ -192,7 +211,7 @@ function(cc_test TARGET_NAME)
set(multiValueArgs SRCS DEPS) set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_SRCS}) add_executable(${TARGET_NAME} ${cc_test_SRCS})
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} gtest gtest_main -lstdc++ -lm) target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} gtest gtest_main)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} gtest gtest_main) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} gtest gtest_main)
add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif() endif()
@ -285,7 +304,7 @@ function(go_library TARGET_NAME)
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${${TARGET_NAME}_LIB_PATH}" COMMAND rm "${${TARGET_NAME}_LIB_PATH}"
# Golang build source code # Golang build source code
COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH} GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE} COMMAND GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${${TARGET_NAME}_LIB_PATH}" -o "${${TARGET_NAME}_LIB_PATH}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}"
# must run under GOPATH # must run under GOPATH
@ -335,3 +354,12 @@ function(proto_library TARGET_NAME)
protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS}) protobuf_generate_cpp(proto_srcs proto_hdrs ${proto_library_SRCS})
cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS ${proto_library_DEPS} protobuf) cc_library(${TARGET_NAME} SRCS ${proto_srcs} DEPS ${proto_library_DEPS} protobuf)
endfunction() endfunction()
function(py_proto_compile TARGET_NAME)
set(oneValueArgs "")
set(multiValueArgs SRCS)
cmake_parse_arguments(py_proto_compile "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(py_srcs)
protobuf_generate_python(py_srcs ${py_proto_compile_SRCS})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${py_srcs})
endfunction()

@ -101,7 +101,7 @@
</div> </div>
<div class="site-nav-links"> <div class="site-nav-links">
<div class="site-menu"> <div class="site-menu">
<a class="fork-on-github" href="https://github.com/PaddlePaddle/Paddle" target="_blank"><i class="fa fa-github"></i>Folk me on Github</a> <a class="fork-on-github" href="https://github.com/PaddlePaddle/Paddle" target="_blank"><i class="fa fa-github"></i>Fork me on Github</a>
<div class="language-switcher dropdown"> <div class="language-switcher dropdown">
<a type="button" data-toggle="dropdown"> <a type="button" data-toggle="dropdown">
<span>English</span> <span>English</span>

@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
@ -31,18 +33,20 @@ func main() {
log.SetLevel(level) log.SetLevel(level)
var idx int var idx int
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
timeout := time.Second * time.Duration((*etcdTimeout)) timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register() idx, err = e.Register()
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
s, err := pserver.NewService(idx) s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil { if err != nil {
panic(err) panic(err)
} }

@ -104,11 +104,22 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return C.PADDLE_MASTER_OK return C.PADDLE_MASTER_OK
} }
// return value:
// 0:ok
// -1:error
//export paddle_next_record //export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client) c := get(client)
r := c.NextRecord() r, err := c.NextRecord()
if err != nil {
// Error
// TODO: return the type of error?
*record = (*C.uchar)(nullPtr)
return -1
}
if len(r) == 0 { if len(r) == 0 {
// Empty record
*record = (*C.uchar)(nullPtr) *record = (*C.uchar)(nullPtr)
return 0 return 0
} }

@ -11,7 +11,12 @@ import (
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
ch chan []byte ch chan record
}
type record struct {
r []byte
err error
} }
// NewClient creates a new Client. // NewClient creates a new Client.
@ -21,7 +26,7 @@ type Client struct {
func NewClient(addrCh <-chan string, bufSize int) *Client { func NewClient(addrCh <-chan string, bufSize int) *Client {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan []byte, bufSize) c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh) go c.monitorMaster(addrCh)
go c.getRecords() go c.getRecords()
return c return c
@ -46,10 +51,11 @@ func (c *Client) getRecords() {
s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1) s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
for s.Scan() { for s.Scan() {
c.ch <- s.Record() c.ch <- record{s.Record(), nil}
} }
if s.Err() != nil { if s.Err() != nil {
c.ch <- record{nil, s.Err()}
log.Errorln(err, chunk.Path) log.Errorln(err, chunk.Path)
} }
@ -116,6 +122,7 @@ func (c *Client) taskFinished(taskID int) error {
// //
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
// thread-safe. // thread-safe.
func (c *Client) NextRecord() []byte { func (c *Client) NextRecord() ([]byte, error) {
return <-c.ch r := <-c.ch
return r.r, r.err
} }

@ -68,12 +68,17 @@ func TestNextRecord(t *testing.T) {
for pass := 0; pass < 50; pass++ { for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool) received := make(map[byte]bool)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
r := c.NextRecord() r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
}
if len(r) != 1 { if len(r) != 1 {
t.Fatal("Length should be 1.", r) t.Fatal(pass, i, "Length should be 1.", r)
} }
if received[r[0]] { if received[r[0]] {
t.Fatal("Received duplicate.", received, r) t.Fatal(pass, i, "Received duplicate.", received, r)
} }
received[r[0]] = true received[r[0]] = true
} }

@ -1,6 +1,8 @@
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING) if(WITH_TESTING)
# TODO: add unit test # FIXME: this test requires pserver which is not managed by the test
#add_subdirectory(test) # we need some kind of e2e testing machanism.
# add_subdirectory(test)
endif() endif()

@ -18,6 +18,8 @@ const (
PsDesired = "/ps_desired" PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr // PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/"
) )
// EtcdClient is the etcd client that the pserver uses for fault // EtcdClient is the etcd client that the pserver uses for fault
@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil return idx, nil
} }
// PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
_, err := e.etcdClient.Put(ctx, key, string(value))
cancel()
if err != nil {
return err
}
return nil
}

@ -1,7 +1,8 @@
package pserver package pserver
// #cgo CFLAGS: -I ../../ // #cgo CFLAGS: -I ../../
// #cgo LDFLAGS: -lpaddle_go_optimizer -lstdc++ -lm // //FIXME: ldflags contain "build" path
// #cgo LDFLAGS: ${SRCDIR}/../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #include "paddle/optimizer/optimizer.h" // #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h> // #include <stdlib.h>
// #include <string.h> // #include <string.h>
@ -34,29 +35,41 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len] return (*[1 << 30]byte)(p)[:len:len]
} }
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{} o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param p := paramWithConfigs.Param
c := paramWithConfigs.Config c := paramWithConfigs.Config
s := State
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"ElementType": p.ElementType, "ElementType": p.ElementType,
"ParamSize": len(p.Content), "ParamSize": len(p.Content),
"ConfigSize": len(c), "ConfigSize": len(c),
"StateSize": len(s),
}).Info("New Optimizer Created with config:") }).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content))) cbuffer = C.malloc(C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content))) C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
var cstate unsafe.Pointer
if len(s) != 0 {
cstate = unsafe.Pointer(&s[0])
}
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s)))
(*C.char)(nullPtr), 0)
return o return o
} }
func (o *optimizer) GetWeights() []byte { func (o *optimizer) GetWeights() []byte {
var buffer unsafe.Pointer var buffer unsafe.Pointer
buffer_len := C.paddle_optimizer_get_weights(o.opt, &buffer) bufferLen := C.paddle_optimizer_get_weights(o.opt, &buffer)
return cArrayToSlice(buffer, int(buffer_len)*C.sizeof_float) return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
}
func (o *optimizer) GetStates() []byte {
var cbuffer *C.char
cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len))
} }
func (o *optimizer) UpdateParameter(g Gradient) error { func (o *optimizer) UpdateParameter(g Gradient) error {

@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param: p, Param: p,
Config: config, Config: config,
} }
o := newOptimizer(param) o := newOptimizer(param, nil)
o.Cleanup() o.Cleanup()
} }

@ -1,17 +1,31 @@
package pserver package pserver
import ( import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"os"
"path/filepath"
"strconv"
"sync" "sync"
"time"
log "github.com/sirupsen/logrus"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
type ElementType int type ElementType int
const ( const (
// AlreadyInitialized is true if pserver is initialized
AlreadyInitialized = "pserver already initialized" AlreadyInitialized = "pserver already initialized"
Uninitialized = "pserver not fully initialized" // Uninitialized is true if pserver not fully initialized
Uninitialized = "pserver not fully initialized"
) )
// Supported element types // Supported element types
@ -37,26 +51,55 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format Config []byte // parameter configuration in Proto Buffer format
} }
// ParameterCheckpoint is Parameter and State checkpoint
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}
// checkpoint signature
type checkpointMeta struct {
UUID string `json:"uuid"`
Md5sum string `json:"md5sum"`
Timestamp string `json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type Checkpoint []ParameterCheckpoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
type Gradient Parameter type Gradient Parameter
// Service is the RPC service for pserver. // Service is the RPC service for pserver.
type Service struct { type Service struct {
initialized chan struct{} initialized chan struct{}
idx int idx int
checkpointInterval time.Duration
mu sync.Mutex checkpointPath string
optMap map[string]*optimizer client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer
} }
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified.
func NewService(idx int) (*Service, error) { func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
checkpointInterval: time.Second * time.Duration(seconds),
checkpointPath: path,
client: client,
} }
s.optMap = make(map[string]*optimizer) s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
if cp != nil {
for _, item := range cp {
p := item.ParamConfig
st := item.State
s.optMap[p.Param.Name] = newOptimizer(p, st)
}
}
return s, nil return s, nil
} }
@ -76,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is // TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory // properly memory aligned, if not, make copy to a memory
// aligned region. // aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs) s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
return nil return nil
} }
@ -137,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil return nil
} }
// Save tells the parameter server to save parameters. // pserver save checkpoint
func (s *Service) Save(path string, dummy *int) error { func (s *Service) doCheckpoint() error {
<-s.initialized <-s.initialized
s.mu.Lock()
defer s.mu.Unlock()
cp := make([]ParameterCheckpoint, 0, len(s.optMap))
index := 0
for name, opt := range s.optMap {
var pc ParameterCheckpoint
pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights()
pc.State = opt.GetStates()
cp[index] = pc
index++
}
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err := encoder.Encode(cp)
if err != nil {
return err
}
cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String()
h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
// TODO cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
if err != nil {
return err
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
f, err := os.Create(cpMeta.UUID)
defer f.Close()
if err != nil {
return err
}
writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
writer.Flush()
if err != nil {
return err
}
return nil return nil
} }

@ -15,7 +15,8 @@ const (
) )
func TestServiceFull(t *testing.T) { func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.FailNow()
@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{} ch <- struct{}{}
}() }()
wg.Add(1)
go func() {
err := s.Save("", nil)
if err != nil {
errCh <- err
}
wg.Done()
ch <- struct{}{}
}()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
select { select {
@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}

@ -9,6 +9,10 @@ cc_test(enforce_test SRCS enforce_test.cc)
proto_library(attr_type SRCS attr_type.proto) proto_library(attr_type SRCS attr_type.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attr_type) proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type) proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)

@ -0,0 +1,119 @@
#pragma once
#include <boost/variant.hpp>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/framework/enforce.h"
namespace paddle {
namespace framework {
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>>
Attribute;
typedef std::unordered_map<std::string, Attribute> AttributeMap;
// check whether a value(attribute) fit a certain limit
template <typename T>
class LargerThanChecker {
public:
LargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(T& value) const {
PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
}
private:
T lower_bound_;
};
// we can provide users more common Checker, like 'LessThanChecker',
// 'BetweenChecker'...
template <typename T>
class DefaultValueSetter {
public:
DefaultValueSetter(T default_value) : default_value_(default_value) {}
void operator()(T& value) const { value = default_value_; }
private:
T default_value_;
};
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
template <typename T>
class TypedAttrChecker {
typedef std::function<void(T&)> ValueChecker;
public:
TypedAttrChecker(const std::string& attr_name) : attr_name_(attr_name) {}
TypedAttrChecker& LargerThan(const T& lower_bound) {
value_checkers_.push_back(LargerThanChecker<T>(lower_bound));
return *this;
}
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE(default_value_setter_.empty(),
"%s can't have more than one default value!", attr_name_);
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
}
// allow users provide their own checker
TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
value_checkers_.push_back(checker);
return *this;
}
void operator()(AttributeMap& attr_map) const {
if (!attr_map.count(attr_name_)) {
// user do not set this attr
PADDLE_ENFORCE(!default_value_setter_.empty(),
"Attribute '%s' is required!", attr_name_);
// default_value_setter_ has no more than one element
T val;
(default_value_setter_[0])(val);
attr_map[attr_name_] = val;
}
Attribute& attr = attr_map.at(attr_name_);
T& attr_value = boost::get<T>(attr);
for (const auto& checker : value_checkers_) {
checker(attr_value);
}
}
private:
std::string attr_name_;
std::vector<ValueChecker> value_checkers_;
std::vector<ValueChecker> default_value_setter_;
};
// check whether op's all attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap&)> AttrChecker;
public:
template <typename T>
TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name) {
attr_checkers_.push_back(TypedAttrChecker<T>(attr_name));
AttrChecker& checker = attr_checkers_.back();
return *(checker.target<TypedAttrChecker<T>>());
}
void Check(AttributeMap& attr_map) const {
for (const auto& checker : attr_checkers_) {
checker(attr_map);
}
}
private:
std::vector<AttrChecker> attr_checkers_;
};
} // namespace framework
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -0,0 +1,122 @@
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
TEST(OpRegistry, CreateOp) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
op_desc.add_inputs("aa");
op_desc.add_outputs("bb");
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(3.3);
paddle::framework::OpBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run();
std::string str = "CosineOp runs! scale = " + std::to_string(3.3);
ASSERT_EQ(str.size(), debug_str.size());
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
}
TEST(OpRegistry, IllegalAttr) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
op_desc.add_inputs("aa");
op_desc.add_outputs("bb");
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("scale");
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(-2.0);
bool caught = false;
try {
paddle::framework::OpBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
std::string msg = "larger_than check fail";
const char* err_msg = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(err_msg[i], msg[i]);
}
}
ASSERT_TRUE(caught);
}
TEST(OpRegistry, DefaultValue) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("cos_sim");
op_desc.add_inputs("aa");
op_desc.add_outputs("bb");
paddle::framework::OpBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run();
float default_value = 1.0;
std::string str = "CosineOp runs! scale = " + std::to_string(default_value);
ASSERT_EQ(str.size(), debug_str.size());
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
}
TEST(OpRegistry, CustomChecker) {
paddle::framework::OpDesc op_desc;
op_desc.set_type("my_test_op");
op_desc.add_inputs("ii");
op_desc.add_outputs("oo");
// attr 'test_attr' is not set
bool caught = false;
try {
paddle::framework::OpBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
const char* err_msg = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(err_msg[i], msg[i]);
}
}
ASSERT_TRUE(caught);
// set 'test_attr' set to an illegal value
auto attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(3);
caught = false;
try {
paddle::framework::OpBase* op __attribute__((unused)) =
paddle::framework::OpRegistry::CreateOp(op_desc);
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
std::string msg = "'test_attr' must be even!";
const char* err_msg = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(err_msg[i], msg[i]);
}
}
ASSERT_TRUE(caught);
// set 'test_attr' set to a legal value
op_desc.mutable_attrs()->Clear();
attr = op_desc.mutable_attrs()->Add();
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4);
paddle::framework::OpBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc);
std::string debug_str = op->Run();
std::string str = "MyTestOp runs! test_attr = " + std::to_string(4);
ASSERT_EQ(str.size(), debug_str.size());
for (size_t i = 0; i < debug_str.length(); ++i) {
ASSERT_EQ(debug_str[i], str[i]);
}
}

@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
const char* AdadeltaOptimizer::SerializeState(int* state_len) { const char* AdadeltaOptimizer::SerializeState(int* state_len) {
AdadeltaOptimizerState state; AdadeltaOptimizerState state;
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
std::string lr_str = this->lr_policy_->SerializeState(state_len);
state.mutable_lr_state()->ParseFromString(lr_str);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
TensorToProto(*accum_delta_, state.mutable_accum_delta()); TensorToProto(*accum_delta_, state.mutable_accum_delta());
TensorToProto(*update_delta_, state.mutable_update_delta()); TensorToProto(*update_delta_, state.mutable_update_delta());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len += str.size();
return str.c_str(); return str.c_str();
} }
void AdadeltaOptimizer::DeserializeState(const std::string& str) { void AdadeltaOptimizer::DeserializeState(const std::string& str) {
AdadeltaOptimizerState state; AdadeltaOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
// TODO(zhihong) : add lr_policy DeserializeState auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);

@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
} }
const char* AdagradOptimizer::SerializeState(int* state_len) { const char* AdagradOptimizer::SerializeState(int* state_len) {
AdagradOptimizerState state; AdagradOptimizerState state;
// TODO(zhihong) : add lr_policy serialization
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
std::string lr_str = this->lr_policy_->SerializeState(state_len);
state.mutable_lr_state()->ParseFromString(lr_str);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*accum_gradient_, state.mutable_accum_gradient()); TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len += str.size();
return str.c_str(); return str.c_str();
} }
void AdagradOptimizer::DeserializeState(const std::string& str) { void AdagradOptimizer::DeserializeState(const std::string& str) {
AdagradOptimizerState state; AdagradOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
// TODO(zhihong) : add lr_policy DeserializeState auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);
ProtoToTensor(state.accum_gradient(), accum_gradient_); ProtoToTensor(state.accum_gradient(), accum_gradient_);

@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) {
const char *AdamOptimizer::SerializeState(int *state_len) { const char *AdamOptimizer::SerializeState(int *state_len) {
AdamOptimizerState state; AdamOptimizerState state;
// TODO(zhihong) : add lr_policy serialization std::string lr_str = this->lr_policy_->SerializeState(state_len);
state.mutable_lr_state()->ParseFromString(lr_str);
state.set_num_sample_passed(num_sample_passed_); state.set_num_sample_passed(num_sample_passed_);
TensorToProto(*parameter_, state.mutable_parameter()); TensorToProto(*parameter_, state.mutable_parameter());
TensorToProto(*momentums_, state.mutable_momentums()); TensorToProto(*momentums_, state.mutable_momentums());
TensorToProto(*velocitys_, state.mutable_velocitys()); TensorToProto(*velocitys_, state.mutable_velocitys());
auto str = state.SerializeAsString(); auto str = state.SerializeAsString();
*state_len = str.size(); *state_len += str.size();
return str.c_str(); return str.c_str();
} }
void AdamOptimizer::DeserializeState(const std::string &str) { void AdamOptimizer::DeserializeState(const std::string &str) {
AdamOptimizerState state; AdamOptimizerState state;
state.ParseFromString(str); state.ParseFromString(str);
// TODO(zhihong) : add lr_policy DeserializeState auto lr_state = state.lr_state();
this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
num_sample_passed_ = state.num_sample_passed(); num_sample_passed_ = state.num_sample_passed();
ProtoToTensor(state.parameter(), parameter_); ProtoToTensor(state.parameter(), parameter_);

@ -17,36 +17,56 @@ public:
// constant learning rate policy // constant learning rate policy
class ConstLr final : public LrPolicy { class ConstLr final : public LrPolicy {
public: public:
ConstLr(double lr) : learning_rate(lr){}; ConstLr(double lr) : learning_rate_(lr){};
double LearningRate(const uint64_t num_sample_passed) { double LearningRate(const uint64_t num_sample_passed) {
return learning_rate; return learning_rate_;
}
const char *SerializeState(int *state_len) {
LrPolicyState state;
state.set_learning_rate(learning_rate_);
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
}
void DeserializeState(const std::string &str) {
LrPolicyState state;
state.ParseFromString(str);
learning_rate_ = state.learning_rate();
} }
const char *SerializeState(int *state_len) { return nullptr; }
void DeserializeState(const std::string &state) {}
private: private:
double learning_rate; double learning_rate_;
}; };
class LinearLr final : public LrPolicy { class LinearLr final : public LrPolicy {
public: public:
LinearLr(double lr, double lr_decay_a, double lr_decay_b) LinearLr(double lr, double lr_decay_a, double lr_decay_b)
: learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {} : learning_rate_(lr), lr_decay_a_(lr_decay_a), lr_decay_b_(lr_decay_b) {}
double LearningRate(const uint64_t num_sample_passed) { double LearningRate(const uint64_t num_sample_passed) {
return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b); return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed,
lr_decay_b_);
} }
const char *SerializeState(int *state_len) { const char *SerializeState(int *state_len) {
// TODO(zhihong) : add lr_policy serialization LrPolicyState state;
return nullptr; state.set_learning_rate(learning_rate_);
state.set_lr_decay_a(lr_decay_a_);
state.set_lr_decay_b(lr_decay_b_);
auto str = state.SerializeAsString();
*state_len = str.size();
return str.c_str();
} }
void DeserializeState(const std::string &state) { void DeserializeState(const std::string &str) {
// TODO(zhihong) : add lr_policy serialization LrPolicyState state;
state.ParseFromString(str);
learning_rate_ = state.learning_rate();
lr_decay_a_ = state.lr_decay_a();
lr_decay_b_ = state.lr_decay_b();
} }
private: private:
double learning_rate; double learning_rate_;
double lr_decay_a; double lr_decay_a_;
double lr_decay_b; double lr_decay_b_;
}; };
} // namespace optimizer } // namespace optimizer

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save