Merge remote-tracking branch 'baidu/develop' into tensor_to_EigenTensor

cblas_new
qijun 8 years ago
commit 71e2a94310

3
.gitignore vendored

@ -22,3 +22,6 @@ cmake-build-*
# generated while compiling # generated while compiling
python/paddle/v2/framework/core.so python/paddle/v2/framework/core.so
CMakeFiles
cmake_install.cmake

@ -140,6 +140,10 @@ endif(USE_NNPACK)
add_subdirectory(proto) add_subdirectory(proto)
# "add_subdirectory(go)" should be placed after the following line,
# because it depends on paddle/optimizer.
add_subdirectory(paddle/optimizer)
# "add_subdirectory(paddle)" and "add_subdirectory(python)" should be # "add_subdirectory(paddle)" and "add_subdirectory(python)" should be
# placed after this block, because they depends on it. # placed after this block, because they depends on it.
if(WITH_GOLANG) if(WITH_GOLANG)

@ -93,7 +93,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE) if(NOT APPLE)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT}) link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl") set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl -lrt")
endif(NOT APPLE) endif(NOT APPLE)
function(merge_static_libs TARGET_NAME) function(merge_static_libs TARGET_NAME)
@ -301,7 +301,7 @@ function(go_library TARGET_NAME)
file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go") file(GLOB GO_SOURCE RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*.go")
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND rm "${${TARGET_NAME}_LIB_PATH}" COMMAND rm "${${TARGET_NAME}_LIB_PATH}"
# Golang build source code # Golang build source code
@ -309,7 +309,7 @@ function(go_library TARGET_NAME)
-o "${${TARGET_NAME}_LIB_PATH}" -o "${${TARGET_NAME}_LIB_PATH}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${GO_SOURCE}"
# must run under GOPATH # must run under GOPATH
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_dependencies(${TARGET_NAME} go_vendor) add_dependencies(${TARGET_NAME} go_vendor)
endfunction(go_library) endfunction(go_library)
@ -320,14 +320,11 @@ function(go_binary TARGET_NAME)
cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) string(REPLACE "${PADDLE_GO_PATH}/" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
# FIXME: link path
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp add_custom_command(OUTPUT ${TARGET_NAME}_timestamp
COMMAND env LIBRARY_PATH=${CMAKE_BINARY_DIR}/go/pserver/client/c/:$ENV{LIBRARY_PATH} COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
"./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}" "./${CMAKE_CURRENT_SOURCE_REL_DIR}/${go_binary_SRCS}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
# TODO: don't know what ${TARGET_NAME}_link does
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS}) add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${TARGET_NAME}_timestamp ${go_binary_DEPS})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin) install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin)
endfunction(go_binary) endfunction(go_binary)
@ -335,15 +332,18 @@ endfunction(go_binary)
function(go_test TARGET_NAME) function(go_test TARGET_NAME)
set(options OPTIONAL) set(options OPTIONAL)
set(oneValueArgs "") set(oneValueArgs "")
set(multiValueArgs SRCS DEPS) set(multiValueArgs DEPS)
cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cmake_parse_arguments(go_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_custom_command(OUTPUT ${TARGET_NAME}_timestamp string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS})
add_custom_command(TARGET ${TARGET_NAME} POST_BUILD
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" -c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_test_SRCS} ".${CMAKE_CURRENT_SOURCE_REL_DIR}"
WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go")
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test) endfunction(go_test)
function(proto_library TARGET_NAME) function(proto_library TARGET_NAME)

@ -17,3 +17,7 @@ add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver) add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master) add_subdirectory(cmd/master)
add_subdirectory(master/c) add_subdirectory(master/c)
add_subdirectory(master)
add_subdirectory(pserver)
add_subdirectory(pserver/client)
add_subdirectory(utils/networkhelper)

@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
go_binary(master SRC master.go DEPS paddle_go_optimizer) go_binary(master SRC master.go)

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -18,53 +19,47 @@ func main() {
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
level, err := log.ParseLevel(*logLevel) level, err := log.ParseLevel(*logLevel)
if err != nil { candy.Must(err)
panic(err)
}
log.SetLevel(level) log.SetLevel(level)
var idx int var idx int
var cp pserver.Checkpoint
var cp *pserver.Checkpoint
var e *pserver.EtcdClient var e *pserver.EtcdClient
if *index >= 0 { if *index >= 0 {
idx = *index idx = *index
} else { } else {
timeout := time.Second * time.Duration((*etcdTimeout)) e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout)
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register() idx, err = e.Register()
candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e)
if err != nil { if err != nil {
panic(err) log.Errorf("Fetch checkpoint failed, %s", err)
} }
} }
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil { candy.Must(err)
panic(err)
}
err = rpc.Register(s) err = rpc.Register(s)
if err != nil { candy.Must(err)
panic(err)
}
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { candy.Must(err)
panic(err)
}
log.Infof("start pserver at port %d", *port) log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil) err = http.Serve(l, nil)
candy.Must(err)
if err != nil {
panic(err)
}
} }

14
go/glide.lock generated

@ -1,8 +1,8 @@
hash: b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7 hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
updated: 2017-06-27T14:05:48.925262819+08:00 updated: 2017-07-11T10:04:40.786745417+08:00
imports: imports:
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: 61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7 version: cb2a496c4ddd1c87a9f280e116649b599999ec79
subpackages: subpackages:
- auth/authpb - auth/authpb
- clientv3 - clientv3
@ -22,7 +22,9 @@ imports:
- name: github.com/PaddlePaddle/recordio - name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129 version: edfb82af0739c84f241c87390ec5649c7b28c129
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: golang.org/x/net - name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2 version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages: subpackages:
@ -34,11 +36,11 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733 version: abf9c25f54453410d0c6668e519582a9e1115027
subpackages: subpackages:
- unix - unix
- name: golang.org/x/text - name: golang.org/x/text
version: 4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77 version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
subpackages: subpackages:
- secure/bidirule - secure/bidirule
- transform - transform

@ -10,3 +10,4 @@ import:
version: ^1.7.4-pre version: ^1.7.4-pre
- package: github.com/sirupsen/logrus - package: github.com/sirupsen/logrus
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy

@ -0,0 +1,3 @@
# Register the Go test target "master_test" through the project's go_test()
# helper, but only when the WITH_TESTING option is enabled.
if(WITH_TESTING)
go_test(master_test)
endif()

@ -0,0 +1,3 @@
# Register the Go test target "pserver_test" through the project's go_test()
# helper, but only when the WITH_TESTING option is enabled.
# paddle_go_optimizer is passed as DEPS so it is built before the test.
if(WITH_TESTING)
go_test(pserver_test DEPS paddle_go_optimizer)
endif()

@ -0,0 +1,3 @@
# Register the Go test target "pserver_client_test" through the project's
# go_test() helper, but only when the WITH_TESTING option is enabled.
# paddle_go_optimizer is passed as DEPS so it is built before the test.
if(WITH_TESTING)
go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()

@ -0,0 +1 @@
libpaddle_go_optimizer.a

@ -1,5 +1,13 @@
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m) target_link_libraries(paddle_go_optimizer stdc++ m)
# Copy the static library next to the Go sources that link against it.
# The destination has to be the source directory because the cgo directive
# below locates the archive relative to ${SRCDIR}.
# See: go/pserver/optimizer.go:
#   // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
# "${CMAKE_COMMAND} -E copy" is used instead of "cp" so the step does not
# depend on a Unix shell utility; VERBATIM keeps argument quoting consistent
# across generators.
add_custom_command(TARGET paddle_go_optimizer POST_BUILD
  COMMAND ${CMAKE_COMMAND} -E copy
          "${CMAKE_CURRENT_BINARY_DIR}/libpaddle_go_optimizer.a"
          "${CMAKE_CURRENT_SOURCE_DIR}"
  COMMENT "Copying libpaddle_go_optimizer.a into the source tree for cgo"
  VERBATIM
  )
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING) if(WITH_TESTING)
# FIXME: this test requires pserver which is not managed by the test # FIXME: this test requires pserver which is not managed by the test

@ -16,7 +16,7 @@ import (
const ( const (
// PsDesired is etcd path for store desired pserver count // PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired" PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr // PsPath is the base dir for pserver to store their addr
PsPath = "/ps/" PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information // PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/" PsCheckpoint = "/checkpoints/"
@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil return idx, nil
} }
// GetKey looks up key in etcd, waiting at most timeout for the reply.
//
// On a failed lookup it returns an empty (non-nil) slice together with the
// error. When the reply holds no entries it returns an empty slice and a nil
// error. Otherwise the value of the first returned entry is handed back.
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	// Release the timeout context as soon as the request has completed.
	defer cancel()

	resp, err := e.etcdClient.Get(ctx, key)
	switch {
	case err != nil:
		return []byte{}, err
	case len(resp.Kvs) == 0:
		return []byte{}, nil
	}
	return resp.Kvs[0].Value, nil
}
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.etcdClient.Put(ctx, key, string(value)) _, err := e.etcdClient.Put(ctx, key, string(value))
cancel() cancel()
if err != nil { if err != nil {

@ -1,8 +1,7 @@
package pserver package pserver
// #cgo CFLAGS: -I ../../ // #cgo CFLAGS: -I ../../
// //FIXME: ldflags contain "build" path // #cgo LDFLAGS: ${SRCDIR}/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #cgo LDFLAGS: ${SRCDIR}/../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ -lm
// #include "paddle/optimizer/optimizer.h" // #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h> // #include <stdlib.h>
// #include <string.h> // #include <string.h>

@ -9,6 +9,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strconv" "strconv"
@ -21,14 +22,14 @@ import (
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
type ElementType int type ElementType int
// RPC error message.
const ( const (
// AlreadyInitialized is true if pserver is initialized AlreadyInitialized = "pserver already initialized"
AlreadyInitialized = "pserver already initialized" Uninitialized = "pserver not fully initialized"
// Uninitialized is true if pserver not fully initialized CheckpointMD5Failed = "checkpoint file MD5 validation failed"
Uninitialized = "pserver not fully initialized"
) )
// Supported element types // Supported element types.
const ( const (
Int32 ElementType = iota Int32 ElementType = iota
UInt32 UInt32
@ -51,21 +52,15 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format Config []byte // parameter configuration in Proto Buffer format
} }
// ParameterCheckpoint is Parameter and State checkpoint // checkpointMeta saves checkpoint metadata
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}
// checkpoint signature
type checkpointMeta struct { type checkpointMeta struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Md5sum string `json:"md5sum"` MD5 string `json:"md5"`
Timestamp string `json:"timestamp"` Timestamp int64 `json:"timestamp"`
} }
// Checkpoint is the pserver shard persist in file // Checkpoint is the pserver shard persist in file
type Checkpoint []ParameterCheckpoint type Checkpoint []parameterCheckpoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
type Gradient Parameter type Gradient Parameter
@ -81,12 +76,53 @@ type Service struct {
optMap map[string]*optimizer optMap map[string]*optimizer
} }
// parameterCheckpoint saves parameter checkpoint: one parameter together with
// its configuration (the embedded ParameterWithConfig) and the optimizer
// state bytes captured for it.
type parameterCheckpoint struct {
ParameterWithConfig // embedded, so Param and Config are accessible directly
State []byte // filled from the optimizer's GetStates() when checkpointing
}
// NewCheckpointFromFile loads parameters and state from checkpoint file.
//
// The checkpoint metadata of pserver idx is fetched from etcd under the same
// key doCheckpoint stores it at (PsCheckpoint joined with the decimal index).
// The checkpoint file named by the metadata is then read, compared against the
// recorded checksum and gob-decoded into a Checkpoint.
//
// cpPath is kept for interface compatibility. It is not joined with the
// metadata's UUID because doCheckpoint already records the full file path
// (checkpoint path + index) in UUID and creates the file at exactly that path.
//
// An error is returned when the etcd client is nil, the metadata is missing
// or malformed, the file cannot be read, the checksum differs
// (CheckpointMD5Failed) or decoding fails.
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) {
	if e == nil {
		return nil, errors.New("etcd client is required to load a checkpoint")
	}

	// strconv.Itoa is required here: string(idx) would produce the character
	// with code point idx rather than the decimal representation of idx.
	key := filepath.Join(PsCheckpoint, strconv.Itoa(idx))
	v, err := e.GetKey(key, 3*time.Second)
	if err != nil {
		return nil, err
	}
	// GetKey reports a missing key as an empty value with a nil error.
	if len(v) == 0 {
		return nil, errors.New("checkpoint metadata not found in etcd: " + key)
	}

	var cpMeta checkpointMeta
	if err = json.Unmarshal(v, &cpMeta); err != nil {
		return nil, err
	}

	// A missing file is reported by ReadFile itself, so no separate Stat call
	// is needed.
	content, err := ioutil.ReadFile(cpMeta.UUID)
	if err != nil {
		return nil, err
	}

	// NOTE(review): Sum(b) appends the digest of the data written to the hash
	// (nothing here) to b; it does not hash b. The expression is kept in this
	// form because doCheckpoint produces cpMeta.MD5 the same way, so the two
	// values stay comparable. Switch both sides to md5.Sum together.
	checksum := hex.EncodeToString(md5.New().Sum(content))
	if checksum != cpMeta.MD5 {
		return nil, errors.New(CheckpointMD5Failed)
	}

	cp := &Checkpoint{}
	if err = gob.NewDecoder(bytes.NewReader(content)).Decode(cp); err != nil {
		return nil, err
	}
	return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) {
s := &Service{ s := &Service{
idx: idx, idx: idx,
checkpointInterval: time.Second * time.Duration(seconds), checkpointInterval: interval,
checkpointPath: path, checkpointPath: path,
client: client, client: client,
} }
@ -94,10 +130,12 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
if cp != nil { if cp != nil {
for _, item := range cp { for _, item := range *cp {
p := item.ParamConfig p := ParameterWithConfig{
st := item.State Param: item.Param,
s.optMap[p.Param.Name] = newOptimizer(p, st) Config: item.Config,
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
} }
} }
return s, nil return s, nil
@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
cp := make([]ParameterCheckpoint, 0, len(s.optMap)) cp := make([]parameterCheckpoint, len(s.optMap))
index := 0 index := 0
for name, opt := range s.optMap { for name, opt := range s.optMap {
var pc ParameterCheckpoint var pc parameterCheckpoint
pc.ParamConfig.Param.Name = name pc.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType pc.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights() pc.Param.Content = opt.GetWeights()
pc.State = opt.GetStates() pc.State = opt.GetStates()
cp[index] = pc cp[index] = pc
index++ index++
@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error {
cpMeta := checkpointMeta{} cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String() cpMeta.Timestamp = time.Now().UnixNano()
h := md5.New() h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, _ := json.Marshal(cpMeta) cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
if err != nil { if err != nil {
return err return err
} }
@ -219,7 +257,11 @@ func (s *Service) doCheckpoint() error {
log.Info("checkpoint does not exists.") log.Info("checkpoint does not exists.")
} else { } else {
err = os.Remove(cpMeta.UUID) err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) if err != nil {
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
} else {
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
} }
f, err := os.Create(cpMeta.UUID) f, err := os.Create(cpMeta.UUID)
defer f.Close() defer f.Close()

@ -0,0 +1,3 @@
# Register the Go test target "network_helper_test" through the project's
# go_test() helper, but only when the WITH_TESTING option is enabled.
if(WITH_TESTING)
go_test(network_helper_test)
endif()

@ -8,14 +8,12 @@ add_subdirectory(gserver)
add_subdirectory(pserver) add_subdirectory(pserver)
add_subdirectory(trainer) add_subdirectory(trainer)
add_subdirectory(scripts) add_subdirectory(scripts)
add_subdirectory(optimizer)
add_subdirectory(string) add_subdirectory(string)
if(Boost_FOUND) if(Boost_FOUND)
add_subdirectory(memory) add_subdirectory(memory)
add_subdirectory(platform) add_subdirectory(platform)
add_subdirectory(framework) add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind) add_subdirectory(pybind)
endif() endif()

@ -12,7 +12,7 @@ cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type) proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf) cc_library(operator SRCS operator.cc DEPS op_desc protobuf)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry place)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)

@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
return ((0 <= idx.head) && (idx.head < size.head)); return ((0 <= idx.head) && (idx.head < size.head));
} }
/**
 * \brief Tell whether a size and a stride describe a Fortran order
 * contiguous block of memory.
 *
 * An array whose element count is zero is always reported as contiguous.
 * A dimension of extent 1 is expected to have stride 0; every other
 * dimension is expected to have stride \p mul, where \p mul accumulates the
 * product of the extents of the dimensions already visited.
 */
template <int i>
HOST bool contiguous(const Dim<i>& size, const Dim<i>& stride, int mul = 1) {
  if (product(size) == 0) {
    return true;
  }
  const int head_extent = get<0>(size);
  const int expected_stride = (head_extent == 1) ? 0 : mul;
  if (get<0>(stride) != expected_stride) {
    return false;
  }
  // Leading dimension matches; check the remaining ones.
  return contiguous(size.tail, stride.tail, mul * head_extent);
}

///\cond HIDDEN
// Recursion terminator for contiguous: only the last dimension is left, so
// its stride is compared against the accumulated multiplier.
template <>
inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) {
  const int extent = get<0>(size);
  if (extent == 0) {
    return true;
  }
  const int expected_stride = (extent == 1) ? 0 : mul;
  return get<0>(stride) == expected_stride;
}
///\endcond
/** /**
* \brief Compute exclusive prefix-multiply of a Dim. * \brief Compute exclusive prefix-multiply of a Dim.
*/ */
@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
} }
///\endcond ///\endcond
/**
 * \brief Compute the strides of a contiguous array with the given size.
 *
 * Any dimension whose extent is 1 receives a stride of 0.
 * \param size Dim object holding the extents of the array.
 * \param base Stride assigned to the first dimension.
 * \return Dim object with as many entries as \p size, holding the strides.
 */
template <int i>
HOSTDEVICE Dim<i> contiguous_strides(const Dim<i>& size, int base = 1) {
  const int head_stride = (size.head != 1) ? base : 0;
  // The next dimension advances by the number of elements spanned so far.
  const int next_base = base * size.head;
  return Dim<i>(head_stride, contiguous_strides(size.tail, next_base));
}

///\cond HIDDEN
// Recursion terminator for contiguous_strides: a single remaining dimension.
template <>
HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) {
  const int last_stride = (size.head != 1) ? base : 0;
  return Dim<1>(last_stride);
}
///\endcond
/** /**
* Add two dimensions together * Add two dimensions together
*/ */

@ -58,24 +58,6 @@ TEST(Dim, Equality) {
EXPECT_EQ(paddle::framework::get<1>(c), 3); EXPECT_EQ(paddle::framework::get<1>(c), 3);
EXPECT_EQ(paddle::framework::get<2>(c), 12); EXPECT_EQ(paddle::framework::get<2>(c), 12);
// contiguous_strides
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 0);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 10);
EXPECT_EQ(paddle::framework::get<2>(c), 0);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 0);
EXPECT_EQ(paddle::framework::get<1>(c), 1);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 2);
EXPECT_EQ(paddle::framework::get<2>(c), 6);
// generate from an index // generate from an index
auto size = paddle::framework::make_dim(4, 5, 2); auto size = paddle::framework::make_dim(4, 5, 2);
c = paddle::framework::Dim<3>(14, size); c = paddle::framework::Dim<3>(14, size);
@ -101,16 +83,6 @@ TEST(Dim, Bool) {
EXPECT_TRUE(a == a); EXPECT_TRUE(a == a);
EXPECT_FALSE(a == b); EXPECT_FALSE(a == b);
EXPECT_TRUE(a == c); EXPECT_TRUE(a == c);
// contiguous check
int x = 4, y = 5, z = 2;
paddle::framework::Dim<3> sizef(x, y, z);
paddle::framework::Dim<3> stridea(1, x, x*y);
paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y);
paddle::framework::Dim<3> stridec(1, x, 2*x*y);
EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea));
EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb));
EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec));
} }
TEST(Dim, Print) { TEST(Dim, Print) {

@ -147,13 +147,13 @@ class OpRegisterHelper {
} }
}; };
#define REGISTER_OP(__op_class, __op_maker_class, __op_type) \ #define REGISTER_OP(type, op_class, op_maker_class) \
class __op_class##Register { \ class op_class##Register { \
private: \ private: \
const static OpRegisterHelper<__op_class, __op_maker_class> reg; \ const static OpRegisterHelper<op_class, op_maker_class> reg; \
}; \ }; \
const OpRegisterHelper<__op_class, __op_maker_class> \ const OpRegisterHelper<op_class, op_maker_class> op_class##Register::reg( \
__op_class##Register::reg(#__op_type); #type)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -1,17 +1,15 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"
using namespace paddle::framework; using namespace paddle::framework;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class CosineOp : public OperatorWithKernel { class CosineOp : public OperatorBase {
public: public:
void Run(const OpRunContext* context) const override { void Run(const std::shared_ptr<Scope>& scope,
printf("%s\n", DebugString().c_str()); const platform::DeviceContext& dev_ctx) const override {}
} void InferShape(const std::shared_ptr<Scope>& scope) const override {}
}; };
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@ -28,14 +26,15 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} }
}; };
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim) REGISTER_OP(cos_sim, CosineOp, CosineOpProtoAndCheckerMaker);
class MyTestOp : public OperatorBase {
public:
void InferShape(const std::shared_ptr<Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override {}
class MyTestOp : public OperatorWithKernel {
public: public:
void Run(const OpRunContext* ctx) const override {
printf("%s\n", DebugString().c_str());
printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
}
}; };
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
@ -54,7 +53,7 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
} }
}; };
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op) REGISTER_OP(my_test_op, MyTestOp, MyTestOpProtoAndCheckerMaker);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale"); float scale_get = op->GetAttr<float>("scale");
ASSERT_EQ(scale_get, scale); ASSERT_EQ(scale_get, scale);
} }
@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0); ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
} }
@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(4); attr->set_i(4);
paddle::framework::OperatorBase* op = paddle::framework::OperatorBase* op =
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc);
auto dev_ctx = DeviceContext(); paddle::platform::CPUDeviceContext dev_ctx;
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
op->Run(scope, &dev_ctx); op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr"); int test_attr = op->GetAttr<int>("test_attr");
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
} }

@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
return ss.str(); return ss.str();
} }
const Variable* OpRunContext::Input(int index) const {
return scope_->GetVariable(op_->inputs_[index]);
}
Variable* OpRunContext::Output(int index) const {
return scope_->GetVariable(op_->outputs_[index]);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -14,44 +14,22 @@ limitations under the License. */
#pragma once #pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp> #include <boost/variant.hpp>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/utils/Error.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase; class OperatorBase;
class DeviceContext {};
/**
* OpRunContext is the only parameter of Operator's Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* OpRunContext. User should construct it before run the Operator.
*/
class OpRunContext {
public:
OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
const DeviceContext* device_context)
: op_(op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const;
Variable* Output(int index) const;
public:
const OperatorBase* op_;
const std::shared_ptr<Scope> scope_;
const DeviceContext* device_context_;
};
/** /**
* OperatorBase has the basic element that Net will call to do computation. * OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User * Only CreateOperator from OpRegistry will new Operator directly. User
@ -77,7 +55,10 @@ class OperatorBase {
/// Net will call this function to Run an op. /// Net will call this function to Run an op.
virtual void Run(const std::shared_ptr<Scope>& scope, virtual void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const = 0; const platform::DeviceContext& dev_ctx) const = 0;
protected:
std::string Type() const { return desc_.type(); }
public: public:
OpDesc desc_; OpDesc desc_;
@ -86,22 +67,84 @@ class OperatorBase {
AttributeMap attrs_; AttributeMap attrs_;
}; };
class OpKernel {
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class KernelContext {
public:
KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& device_context)
: op_(*op), scope_(scope), device_context_(device_context) {}
const Variable* Input(int index) const {
return scope_->GetVariable(op_.inputs_[index]);
}
Variable* Output(int index) const {
return scope_->GetVariable(op_.outputs_[index]);
}
const OperatorBase& op_;
const std::shared_ptr<Scope>& scope_;
const platform::DeviceContext& device_context_;
};
virtual void Compute(const KernelContext& context) const = 0;
virtual ~OpKernel() {}
};
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
virtual ~OperatorWithKernel() {} struct OpKernelKey {
platform::Place place_;
virtual void InferShape(const std::shared_ptr<Scope>& scope) const {} OpKernelKey() = default;
OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}
bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
};
struct OpKernelHash {
std::hash<bool> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
}
};
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const DeviceContext* dev_ctx) const { const platform::DeviceContext& dev_ctx) const final {
OpRunContext op_ctx(this, scope, dev_ctx); auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
Run(&op_ctx); opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
} }
/// when implement an Op, your should implement this function. static std::unordered_map<std::string /* op_type */, OpKernelMap>&
/// this function should be moved to OpKernel later AllOpKernels() {
virtual void Run(const OpRunContext* context) const = 0; static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
return g_all_op_kernels;
};
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save