Merge branch 'develop' of https://github.com/paddlepaddle/paddle into voc_dataset

cblas_new
wanghaoshuang 8 years ago
commit a5239ac7a5

3
.gitignore vendored

@ -19,3 +19,6 @@ third_party/
# clion workspace.
cmake-build-*
# generated while compiling
python/paddle/v2/framework/core.so

@ -97,6 +97,7 @@ include(external/swig) # download, build, install swig
include(external/warpctc) # download, build, install warpctc
include(external/any) # download libn::any
include(external/eigen) # download eigen3
include(external/pybind11) # download pybind11
include(cudnn) # set cudnn libraries, must before configure
include(configure) # add paddle env configuration

@ -2,8 +2,8 @@
[![Build Status](https://travis-ci.org/PaddlePaddle/Paddle.svg?branch=develop)](https://travis-ci.org/PaddlePaddle/Paddle)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://www.paddlepaddle.org/develop/doc/)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://www.paddlepaddle.org/doc_cn/)
[![Documentation Status](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](http://doc.paddlepaddle.org/develop/doc/)
[![Documentation Status](https://img.shields.io/badge/中文文档-最新-brightgreen.svg)](http://doc.paddlepaddle.org/develop/doc_cn/)
[![Coverage Status](https://coveralls.io/repos/github/PaddlePaddle/Paddle/badge.svg?branch=develop)](https://coveralls.io/github/PaddlePaddle/Paddle?branch=develop)
[![Release](https://img.shields.io/github/release/PaddlePaddle/Paddle.svg)](https://github.com/PaddlePaddle/Paddle/releases)
[![License](https://img.shields.io/badge/license-Apache%202-blue.svg)](LICENSE)
@ -61,35 +61,36 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl
## Installation
It is recommended to check out the
[Docker installation guide](http://www.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html)
[Docker installation guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html)
before looking into the
[build from source guide](http://www.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html)
[build from source guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html)
## Documentation
We provide [English](http://www.paddlepaddle.org/develop/doc/) and
[Chinese](http://www.paddlepaddle.org/doc_cn/) documentation.
We provide [English](http://doc.paddlepaddle.org/develop/doc/) and
[Chinese](http://doc.paddlepaddle.org/doc_cn/) documentation.
- [Deep Learning 101](http://book.paddlepaddle.org/index.html)
You might want to start from the this online interactive book that can run in Jupyter Notebook.
- [Distributed Training](http://www.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html)
- [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html)
You can run distributed training jobs on MPI clusters.
- [Distributed Training on Kubernetes](http://www.paddlepaddle.org/develop/doc/howto/usage/k8s/k8s_en.html)
- [Distributed Training on Kubernetes](http://doc.paddlepaddle.org/develop/doc/howto/usage/k8s/k8s_en.html)
You can also run distributed training jobs on Kubernetes clusters.
- [Python API](http://www.paddlepaddle.org/develop/doc/api/index_en.html)
- [Python API](http://doc.paddlepaddle.org/develop/doc/api/index_en.html)
Our new API enables much shorter programs.
- [How to Contribute](http://www.paddlepaddle.org/develop/doc/howto/dev/contribute_to_paddle_en.html)
- [How to Contribute](http://doc.paddlepaddle.org/develop/doc/howto/dev/contribute_to_paddle_en.html)
We appreciate your contributions!
## Ask Questions
You are welcome to submit questions and bug reports as [Github Issues](https://github.com/PaddlePaddle/Paddle/issues).

@ -0,0 +1,30 @@
INCLUDE(ExternalProject)
SET(PYBIND_SOURCE_DIR ${THIRD_PARTY_PATH}/pybind)
INCLUDE_DIRECTORIES(${PYBIND_SOURCE_DIR}/src/extern_pybind/include)
ExternalProject_Add(
extern_pybind
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/pybind/pybind11.git"
GIT_TAG "v2.1.1"
PREFIX ${PYBIND_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/pybind_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";")
add_library(pybind STATIC ${dummyfile})
else()
add_library(pybind INTERFACE)
endif()
add_dependencies(pybind extern_pybind)
LIST(APPEND external_project_dependencies pybind)

@ -18,6 +18,9 @@ INCLUDE(python_module)
FIND_PACKAGE(PythonInterp 2.7)
IF(WITH_PYTHON)
FIND_PACKAGE(PythonLibs 2.7)
# Fixme: Maybe find a static library. Get SHARED/STATIC by FIND_PACKAGE.
ADD_LIBRARY(python SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET python PROPERTY IMPORTED_LOCATION ${PYTHON_LIBRARIES})
ENDIF(WITH_PYTHON)
SET(py_env "")

@ -109,7 +109,9 @@ set(COMMON_FLAGS
-Wno-unused-function
-Wno-error=literal-suffix
-Wno-error=sign-compare
-Wno-error=unused-local-typedefs)
-Wno-error=unused-local-typedefs
-Wno-error=parentheses-equality # Warnings in Pybind11
)
set(GPU_COMMON_FLAGS
-fPIC

@ -93,6 +93,7 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR})
if(NOT APPLE)
find_package(Threads REQUIRED)
link_libraries(${CMAKE_THREAD_LIBS_INIT})
set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -ldl")
endif(NOT APPLE)
function(merge_static_libs TARGET_NAME)

@ -37,7 +37,7 @@
\frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x}
假设 :math:`z = f(W^T x + b)` ,那么
假设 :math:`z = W^T x + b` ,那么
.. math::

@ -37,7 +37,7 @@ Suppose our loss function is :math:`c(y)`, then
\frac{\partial c(y)}{\partial x} = \frac{\partial c(y)}{\partial y} \frac{\partial y}{\partial x}
Suppose :math:`z = f(W^T x + b)`, then
Suppose :math:`z = W^T x + b`, then
.. math::

@ -41,7 +41,7 @@ PaddlePaddle文档需要准备的环境相对较复杂所以我们推荐使
python -c "import py_paddle"
如果提示错误那么用户需要在本地编译安装PaddlePaddle请参考 `源码编译文档 <http://www.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html>`_
如果提示错误那么用户需要在本地编译安装PaddlePaddle请参考 `源码编译文档 <http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html>`_
注意用户在首次编译安装PaddlePaddle时请将WITH_DOC选项关闭。在编译安装正确之后请再次确认py_paddle包已经安装即可进行下一步操作。
如果提示正确,可以执行以下命令编译生成文档,即
@ -68,9 +68,9 @@ PaddlePaddle文档使用 `sphinx`_ 自动生成用户可以参考sphinx教程
如何更新www.paddlepaddle.org文档
================================
开发者给PaddlePaddle代码增加的注释以PR的形式提交到github中提交方式可参见 `贡献文档 <http://paddlepaddle.org/develop/doc_cn/howto/dev/contribute_to_paddle_cn.html>`_
目前PaddlePaddle的develop分支的文档是自动触发更新的用户可以分别查看最新的 `中文文档 <http://www.paddlepaddle.org/develop/doc_cn/>`_
`英文文档 <http://www.paddlepaddle.org/develop/doc/>`_
开发者给PaddlePaddle代码增加的注释以PR的形式提交到github中提交方式可参见 `贡献文档 <http://doc.paddlepaddle.org/develop/doc_cn/howto/dev/contribute_to_paddle_cn.html>`_
目前PaddlePaddle的develop分支的文档是自动触发更新的用户可以分别查看最新的 `中文文档 <http://doc.paddlepaddle.org/develop/doc_cn/>`_
`英文文档 <http://doc.paddlepaddle.org/develop/doc/>`_

@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
@ -31,18 +33,20 @@ func main() {
log.SetLevel(level)
var idx int
var cp pserver.Checkpoint
var e *pserver.EtcdClient
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err = e.Register()
if err != nil {
panic(err)
}
}
s, err := pserver.NewService(idx)
s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp)
if err != nil {
panic(err)
}

@ -68,7 +68,7 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.ID)
c.taskFinished(t.Meta.ID)
}
}
@ -118,6 +118,11 @@ func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil)
}
// TaskFailed tell the master server as task is failed.
func (c *Client) taskFailed(meta TaskMeta) error {
return c.conn.Call("Service.TaskFailed", meta, nil)
}
// NextRecord returns next record in the dataset.
//
// NextRecord will block until the next record is available. It is

@ -95,10 +95,16 @@ func TestGetFinishTask(t *testing.T) {
t.Fatalf("Should get error, pass: %d\n", i)
}
err = c.taskFinished(tasks[0].ID)
err = c.taskFinished(tasks[0].Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
@ -107,7 +113,7 @@ func TestGetFinishTask(t *testing.T) {
tasks = append(tasks, task)
for _, task := range tasks {
err = c.taskFinished(task.ID)
err = c.taskFinished(task.Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}

@ -31,30 +31,36 @@ type Chunk struct {
Index recordio.Index // chunk index
}
// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
ID int
Epoch int
}
// Task is the basic unit of data instances assigned to trainers.
type Task struct {
ID int
Meta TaskMeta
Chunks []Chunk
}
type taskEntry struct {
Epoch int
NumTimeout int
Task Task
// A task fails if it's timeout or trainer reports it exits unnormally.
NumFailure int
}
type taskQueues struct {
Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry
Failed []Task
Failed []taskEntry
}
// Service is the master server service.
type Service struct {
chunksPerTask int
timeoutDur time.Duration
timeoutMax int
failureMax int
ready chan struct{}
store Store
@ -73,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
var cur taskEntry
for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id
cur.Task.Meta.ID = id
id++
result = append(result, cur)
cur.Task.Chunks = nil
@ -83,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
}
if len(cur.Task.Chunks) > 0 {
cur.Task.ID = id
cur.Task.Meta.ID = id
result = append(result, cur)
}
@ -91,11 +97,11 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
}
// NewService creates a new service.
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) {
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
s := &Service{}
s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax
s.failureMax = failureMax
s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{})
@ -257,19 +263,10 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return nil
}
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return func() {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return
}
if t.Epoch != epoch {
func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check.
// schedule of this timeout check or failed status report.
return
}
@ -280,17 +277,31 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
}
}()
delete(s.taskQueues.Pending, t.Task.ID)
delete(s.taskQueues.Pending, t.Task.Meta.ID)
t.NumTimeout++
if t.NumTimeout > s.timeoutMax {
log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout)
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
t.NumFailure++
if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Failed = append(s.taskQueues.Failed, t)
return
}
log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout)
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
}
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return func() {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return
}
s.processFailedTask(t, epoch)
}
}
@ -339,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
t := s.taskQueues.Todo[0]
t.Epoch++
t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.ID] = t
s.taskQueues.Pending[t.Task.Meta.ID] = t
err := s.snapshot()
if err != nil {
return err
}
*task = t.Task
log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID)
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil
}
@ -365,13 +376,12 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t, ok := s.taskQueues.Pending[taskID]
if !ok {
err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err
return nil
}
// task finished, reset timeout
t.NumTimeout = 0
t.NumFailure = 0
s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID)
@ -389,3 +399,22 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
}
return err
}
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[meta.ID]
if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
return nil
}
s.processFailedTask(t, meta.Epoch)
return nil
}

@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
if ts[i].Task.ID != i {
if ts[i].Task.Meta.ID != i {
t.Error(ts[i], i)
}
}

@ -19,7 +19,7 @@ def main():
# create parameters
parameters = paddle.parameters.create(cost)
# create optimizer
# create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0)
#TODO(zhihong) : replace optimizer with new OptimizerConfig

@ -42,7 +42,8 @@ func initClient() [numPserver]int {
ports[i] = p
go func(l net.Listener) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
panic(err)
}
@ -174,7 +175,7 @@ func TestNativeClient(t *testing.T) {
// TODO: tmperary disable etcdClient test for dependency of etcd)
func EtcdClient(t *testing.T) {
initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true))
etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
ClientTest(t, c2)
}

@ -18,6 +18,8 @@ const (
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint = "/checkpoints/"
)
// EtcdClient is the etcd client that the pserver uses for fault
@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return idx, nil
}
// PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
_, err := e.etcdClient.Put(ctx, key, string(value))
cancel()
if err != nil {
return err
}
return nil
}

@ -35,22 +35,30 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len]
}
func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer {
o := &optimizer{}
o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := State
paramBufferSize := C.size_t(len(p.Content) / C.sizeof_float)
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": len(p.Content),
"ParamSize": paramBufferSize,
"ConfigSize": len(c),
"StateSize": len(s),
}).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
cbuffer = C.malloc(paramBufferSize)
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), paramBufferSize)
var cstate unsafe.Pointer
if len(s) != 0 {
cstate = unsafe.Pointer(&s[0])
}
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float),
(*C.char)(nullPtr), 0)
C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
return o
}
@ -60,6 +68,12 @@ func (o *optimizer) GetWeights() []byte {
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
}
func (o *optimizer) GetStates() []byte {
var cbuffer *C.char
cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
}
func (o *optimizer) UpdateParameter(g Gradient) error {
if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)

@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param: p,
Config: config,
}
o := newOptimizer(param)
o := newOptimizer(param, nil)
o.Cleanup()
}

@ -1,9 +1,21 @@
package pserver
import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// ElementType is the type of elements of a Parameter.
@ -39,6 +51,22 @@ type ParameterWithConfig struct {
Config []byte // parameter configuration in Proto Buffer format
}
// ParameterCheckpoint is Parameter and State checkpoint
type ParameterCheckpoint struct {
ParamConfig ParameterWithConfig
State []byte
}
// checkpoint signature
type checkpointMeta struct {
UUID string `json:"uuid"`
Md5sum string `json:"md5sum"`
Timestamp string `json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type Checkpoint []ParameterCheckpoint
// Gradient is the gradient of the parameter.
type Gradient Parameter
@ -46,19 +74,32 @@ type Gradient Parameter
type Service struct {
initialized chan struct{}
idx int
checkpointInterval time.Duration
checkpointPath string
client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(idx int) (*Service, error) {
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: time.Second * time.Duration(seconds),
checkpointPath: path,
client: client,
}
s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{})
if cp != nil {
for _, item := range cp {
p := item.ParamConfig
st := item.State
s.optMap[p.Param.Name] = newOptimizer(p, st)
}
}
return s, nil
}
@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// aligned region.
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
return nil
}
@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil
}
// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
// pserver save checkpoint
func (s *Service) doCheckpoint() error {
<-s.initialized
s.mu.Lock()
defer s.mu.Unlock()
cp := make([]ParameterCheckpoint, 0, len(s.optMap))
index := 0
for name, opt := range s.optMap {
var pc ParameterCheckpoint
pc.ParamConfig.Param.Name = name
pc.ParamConfig.Param.ElementType = opt.elementType
pc.ParamConfig.Param.Content = opt.GetWeights()
pc.State = opt.GetStates()
cp[index] = pc
index++
}
var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err := encoder.Encode(cp)
if err != nil {
return err
}
// TODO
cpMeta := checkpointMeta{}
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
cpMeta.Timestamp = time.Now().String()
h := md5.New()
cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, _ := json.Marshal(cpMeta)
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
if err != nil {
return err
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
f, err := os.Create(cpMeta.UUID)
defer f.Close()
if err != nil {
return err
}
writer := bufio.NewWriter(f)
_, err = writer.Write(buf.Bytes())
writer.Flush()
if err != nil {
return err
}
return nil
}

@ -15,7 +15,8 @@ const (
)
func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
}
func TestMultipleInit(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
}
func TestUninitialized(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized {
t.FailNow()
@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
}
func TestBlockUntilInitialized(t *testing.T) {
s, err := pserver.NewService(0)
var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil {
t.Error(err)
}
@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{}
}()
wg.Add(1)
go func() {
err := s.Save("", nil)
if err != nil {
errCh <- err
}
wg.Done()
ch <- struct{}{}
}()
time.Sleep(50 * time.Millisecond)
select {
@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait()
}
func TestCheckpointSpeed(t *testing.T) {
//TODO(zhihong): test speed
}

@ -15,6 +15,8 @@ if(Boost_FOUND)
add_subdirectory(memory)
add_subdirectory(platform)
add_subdirectory(framework)
add_subdirectory(operators)
add_subdirectory(pybind)
endif()
if(WITH_C_API)

@ -11,8 +11,14 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_proto op_desc)
cc_library(operator SRCS operator.cc DEPS op_desc protobuf)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
proto_library(net_proto SRCS net_proto.proto DEPS op_proto)
cc_library(net SRCS net.cc DEPS net_proto)

@ -0,0 +1,20 @@
#include "paddle/framework/net.h"
namespace paddle {
namespace framework {
PlainNet::PlainNet(const NetDesc& def) {}
void PlainNet::InferShape(Scope* scope) {
for (auto& op : ops_) {
op.InferShape();
}
}
void PlainNet::Run(std::shared_ptr<Scope> scope, DeviceContext* ctx) {
for (auto& op : ops_) {
op.Run(ctx);
}
}
} // namespace framework
} // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save