Merge branch 'develop' into fix-import-bug

revert-4814-Add_sequence_project_op
Peng Li 7 years ago
commit a02ebbb5d8

@ -8,7 +8,7 @@ ExternalProject_Add(
extern_eigen3 extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git" GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG 4e79cb69b9425f5f8c3a84be4350d4ab75b5fd9d GIT_TAG 70661066beef694cadf6c304d0d07e0758825c10
PREFIX ${EIGEN_SOURCE_DIR} PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""

@ -1,9 +1,8 @@
INCLUDE(ExternalProject) include(ExternalProject)
SET(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl) set(NCCL_SOURCE_DIR ${THIRD_PARTY_PATH}/nccl)
INCLUDE_DIRECTORIES(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
include_directories(${NCCL_SOURCE_DIR}/src/extern_nccl/src)
if(WITH_DSO) if(WITH_DSO)
# If we use DSO, we do not build nccl, just download the dependencies # If we use DSO, we do not build nccl, just download the dependencies
@ -12,39 +11,39 @@ if(WITH_DSO)
set(NCCL_INSTALL_DIR "") set(NCCL_INSTALL_DIR "")
else() else()
# otherwise, we build nccl and link it. # otherwise, we build nccl and link it.
set(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
# Note: cuda 8.0 is needed to make nccl
# When cuda is not installed on the system directory, need to set CUDA_HOME to your cuda root
set(NCCL_BUILD_COMMAND "make -j 8") set(NCCL_BUILD_COMMAND "make -j 8")
set(NCCL_INSTALL_COMMAND "make install") set(NCCL_INSTALL_COMMAND "make install PREFIX=${NCCL_INSTALL_DIR}")
SET(NCCL_INSTALL_DIR ${THIRD_PARTY_PATH}/install/nccl)
endif() endif()
ExternalProject_Add( ExternalProject_Add(
extern_nccl extern_nccl
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git" GIT_REPOSITORY "https://github.com/NVIDIA/nccl.git"
GIT_TAG "v1.3.4-1" GIT_TAG "v1.3.4-1"
PREFIX "${NCCL_SOURCE_DIR}" PREFIX "${NCCL_SOURCE_DIR}"
UPDATE_COMMAND "" UPDATE_COMMAND ""
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "${NCCL_BUILD_COMMAND}" BUILD_COMMAND "${NCCL_BUILD_COMMAND}"
INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}" INSTALL_COMMAND "${NCCL_INSTALL_COMMAND}"
INSTALL_DIR "${NCCL_INSTALL_DIR}" INSTALL_DIR "${NCCL_INSTALL_DIR}"
TEST_COMMAND "" TEST_COMMAND ""
) )
if (WITH_DSO) if(WITH_DSO)
if (${CMAKE_VERSION} VERSION_LESS "3.3.0") if(${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_any_dummy.c) set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/lib_nccl_dummy.c)
file(WRITE ${dummyfile} "const char * dummy_any = \"${dummyfile}\";") file(WRITE ${dummyfile} "const char * dummy_nccl = \"${dummyfile}\";")
add_library(nccl STATIC ${dummyfile}) add_library(nccl STATIC ${dummyfile})
else() else()
add_library(nccl INTERFACE) add_library(nccl INTERFACE)
endif() endif()
else() else()
ADD_LIBRARY(nccl STATIC IMPORTED GLOBAL) add_library(nccl STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET nccl PROPERTY IMPORTED_LOCATION set_property(TARGET nccl PROPERTY IMPORTED_LOCATION
${NCCL_INSTALL_DIR}/lib/libnccl.a) ${NCCL_INSTALL_DIR}/lib/libnccl_static.a)
endif() endif()
add_dependencies(nccl extern_nccl) add_dependencies(nccl extern_nccl)
LIST(APPEND external_project_dependencies nccl)

@ -0,0 +1,232 @@
## Survey on Graph
Neural network frameworks often provide a symbolic API for users to write network topology conveniently. This doc mainly focuses on the symbolic APIs in the most popular neural network frameworks, and tries to find out how to parse a symbolic configuration into a portable file, such as protobuf or JSON.
### Mxnet
The core concept of symbolic API is `Symbol`. Mxnet implements `Symbol` class in C++, and export to Python using C-API. Please refer to the comments in Mxnet:
`Symbol` is help class used to represent the operator node in Graph.
`Symbol` acts as an interface for building graphs from different components like Variable, Functor and Group. `Symbol` is also exported to python front-end (while Graph is not) to enable quick test and deployment. Conceptually, symbol is the final operation of a graph and thus including all the information required (the graph) to evaluate its output value.
A simple network topology written with Symbol is as follows:
```python
def get_symbol(num_classes=10, **kwargs):
data = mx.symbol.Variable('data')
data = mx.symbol.Flatten(data=data)
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=num_classes)
mlp = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
return mlp
```
A Variable here is actually a Symbol. Every basic Symbol corresponds to one Node, and every Node has its own NodeAttr. There is an op field in the NodeAttr class; when a Symbol represents a Variable (often input data), the op field is null.
Symbol contains a data member, std::vector<NodeEntry> outputs, and NodeEntry contains a pointer to Node. We can follow the Node pointers to traverse the whole Graph.
And Symbol can be saved to a Json file.
Here is a detailed example:
```
>>> import mxnet as mx
>>> data = mx.symbol.Variable('data')
>>> print data.debug_str()
Variable:data
>>> data = mx.symbol.Flatten(data=data)
>>> print data.debug_str()
Symbol Outputs:
output[0]=flatten0(0)
Variable:data
--------------------
Op:Flatten, Name=flatten0
Inputs:
arg[0]=data(0) version=0
>>> fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
>>> print fc1.debug_str()
Symbol Outputs:
output[0]=fc1(0)
Variable:data
--------------------
Op:Flatten, Name=flatten0
Inputs:
arg[0]=data(0) version=0
Variable:fc1_weight
Variable:fc1_bias
--------------------
Op:FullyConnected, Name=fc1
Inputs:
arg[0]=flatten0(0)
arg[1]=fc1_weight(0) version=0
arg[2]=fc1_bias(0) version=0
Attrs:
num_hidden=128
```
### TensorFlow
The core concept of symbolic API is `Tensor`. Tensorflow defines `Tensor` in Python. Please refer to the comments in TensorFlow:
A `Tensor` is a symbolic handle to one of the outputs of an `Operation`. It does not hold the values of that operation's output, but instead provides a means of computing those values in a TensorFlow [Session](https://www.tensorflow.org/api_docs/python/tf/Session).
A simple example is as follows:
```python
# Build a dataflow graph.
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d)
# Construct a `Session` to execute the graph.
sess = tf.Session()
# Execute the graph and store the value that `e` represents in `result`.
result = sess.run(e)
```
The main properties of `Tensor` are as follows:
```python
@property
def op(self):
"""The `Operation` that produces this tensor as an output."""
return self._op
@property
def dtype(self):
"""The `DType` of elements in this tensor."""
return self._dtype
@property
def graph(self):
"""The `Graph` that contains this tensor."""
return self._op.graph
@property
def name(self):
"""The string name of this tensor."""
if not self._op.name:
raise ValueError("Operation was not named: %s" % self._op)
return "%s:%d" % (self._op.name, self._value_index)
@property
def device(self):
"""The name of the device on which this tensor will be produced, or None."""
return self._op.device
```
A Tensor can be taken as a target to be run by a session. A Tensor contains all the information of the Graph, and tracks data dependencies.
Here is a detailed example:
```
>>> import tensorflow as tf
>>> c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
>>> print c.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
>>> d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
>>> print d.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
>>> e = tf.matmul(c, d)
>>> print e.graph
<tensorflow.python.framework.ops.Graph object at 0x10f256d50>
```
### Dynet
The core concept of symbolic API is `Expression`, and Dynet defines `Expression` class in C++.
A simple example is as follows:
```cpp
ComputationGraph cg;
Expression W = parameter(cg, pW);
Expression in = input(cg, xs[i]);
Expression label = input(cg, ys[i]);
Expression pred = W * in;
Expression loss = square(pred - label);
```
The input data and parameters are also represented by Expression. Every basic Expression corresponds to a Node, and the input data is also a Node.
Expression has a data member ComputationGraph, and the ComputationGraph is modified as the user configures the network. An Expression can be a running target, because an Expression contains all of its dependencies.
Here is a detailed example:
write topology in C++
```
ComputationGraph cg;
Expression W = parameter(cg, pW);
cg.print_graphviz();
Expression pred = W * xs[i];
cg.print_graphviz();
Expression loss = square(pred - ys[i]);
cg.print_graphviz();
```
compile and print
```
# first print
digraph G {
rankdir=LR;
nodesep=.05;
N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
}
# second print
digraph G {
rankdir=LR;
nodesep=.05;
N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
N1 [label="v1 = v0 * -0.98"];
N0 -> N1;
}
# third print
digraph G {
rankdir=LR;
nodesep=.05;
N0 [label="v0 = parameters({1}) @ 0x7ffe4de00110"];
N1 [label="v1 = v0 * -0.98"];
N0 -> N1;
N2 [label="v2 = -1.88387 - v1"];
N1 -> N2;
N3 [label="v3 = -v2"];
N2 -> N3;
N4 [label="v4 = square(v3)"];
N3 -> N4;
}
```
### Conclusion
Actually, Symbol/Tensor/Expression in Mxnet/TensorFlow/Dynet are concepts at the same level. We use the unified name Expression here; a concept at this level has the following features:
- Users write topology with the symbolic API, and every return value is an Expression, including input data and parameters.
- Expression corresponds with a global Graph, and Expression can also be composed.
- Expression tracks all dependencies and can be taken as a run target.

@ -0,0 +1,36 @@
# Design Doc: Model Format
## Motivation
The model is the output of the training process. One complete model consists of two parts, namely, the **topology** and the **parameters**. To support industrial deployment, the model format must be self-contained and must not expose any training source code.
As a result, in PaddlePaddle, the **topology** is represented as a [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/doc/design/program.md), which describes the model structure. The **parameters** contain all the trainable weights in the model; we must support large parameters and efficient serialization/deserialization.
## Implementation
The topology is saved as plain text, specifically, a self-contained protobuf file.
The parameters are saved as a binary file. As we all know, a protobuf message has a size limit of [64M](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We did a [benchmark experiment](https://github.com/PaddlePaddle/Paddle/pull/4610), and its result shows that protobuf is not a good fit for this scenario.
As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description proto, [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header; it contains the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores its values in a contiguous memory buffer; for speed, we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is:
|HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**|
In detail, the byte-level layout of a tensor is shown in the table below. Note that all signed values are written in little-endian.
```text
[offset] [type] [description]
0004 4 bytes integer HeaderLength, the length of LoDTensorDesc
0008 4 bytes integer ContentLength, the length of LodTensor Buffer
0009 1 bytes char TensorDesc
00010 1 bytes char TensorDesc
...
00100 1 bytes char TensorValue
00101 1 bytes char TensorValue
00102 1 bytes char TensorValue ..
...
```
## Summary
We introduce the model format: the `ProgramDesc` describes the **topology**, and a set of binary tensors in the particular format above describes the **parameters**.

@ -65,20 +65,6 @@ class Optimizer(object):
def __init__(self): def __init__(self):
pass pass
def create_backward_pass(self, loss, parameter_list=None):
"""
create and add gradient Operators in BlockDesc to Compute gradients of `loss`
for parameters in parameter_list
Args:
loss: an variable generated by cost function.
parameter_list: parameters that need to compute gradient and update to optimize the lost.
Returns:
list of (parameters, gradients) pair.
"""
return None
def create_optimization_pass(self, parameters_and_grads): def create_optimization_pass(self, parameters_and_grads):
"""Add optimization operators to update gradients to variables. """Add optimization operators to update gradients to variables.
@ -93,7 +79,7 @@ class Optimizer(object):
def minimize(self, loss, parameter_list): def minimize(self, loss, parameter_list):
"""Add operations to minimize `loss` by updating `parameter_list`. """Add operations to minimize `loss` by updating `parameter_list`.
This method combines interface `create_backward_pass()` and This method combines interface `append_backward_ops()` and
`create_optimization_pass()` into one. `create_optimization_pass()` into one.
""" """
params_grads = self.create_backward_pass(loss, parameter_list) params_grads = self.create_backward_pass(loss, parameter_list)

@ -25,9 +25,8 @@ import (
"strings" "strings"
"time" "time"
log "github.com/inconshreveable/log15"
"github.com/namsral/flag" "github.com/namsral/flag"
log "github.com/sirupsen/logrus"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
@ -41,16 +40,20 @@ func main() {
taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.") taskTimeoutMax := flag.Int("task-timeout-max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.") chunkPerTask := flag.Int("chunk-per-task", 10, "chunk per task.")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warn, error, crit")
flag.Parse() flag.Parse()
level, e := log.ParseLevel(*logLevel) lvl, err := log.LvlFromString(*logLevel)
candy.Must(e) if err != nil {
panic(err)
}
log.SetLevel(level) log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
if *endpoints == "" { if *endpoints == "" {
log.Warningln("-endpoints not set, fault tolerance not be enabled.") log.Warn("-endpoints not set, fault tolerance not be enabled.")
} }
var store master.Store var store master.Store
@ -58,23 +61,25 @@ func main() {
eps := strings.Split(*endpoints, ",") eps := strings.Split(*endpoints, ",")
ip, err := networkhelper.GetExternalIP() ip, err := networkhelper.GetExternalIP()
if err != nil { if err != nil {
log.Fatal(err) log.Crit("get external ip error", log.Ctx{"error": err})
panic(err)
} }
addr := fmt.Sprintf("%s:%d", ip, *port) addr := fmt.Sprintf("%s:%d", ip, *port)
store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec) store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error creating etcd client.", log.Ctx{"error": err})
panic(err)
} }
} else { } else {
store = &master.InMemStore{} store = &master.InMemStore{}
} }
shutdown := func() { shutdown := func() {
log.Infoln("shutting down gracefully") log.Info("shutting down gracefully")
err := store.Shutdown() err := store.Shutdown()
if err != nil { if err != nil {
log.Errorln(err) log.Error("shutdown error", log.Ctx{"error": err})
} }
} }
@ -86,24 +91,28 @@ func main() {
s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error creating new service.", log.Ctx{"error": err})
panic(err)
} }
err = rpc.Register(s) err = rpc.Register(s)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error registering to etcd.", log.Ctx{"error": err})
panic(err)
} }
rpc.HandleHTTP() rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error listing to port", log.Ctx{"error": err, "port": *port})
panic(err)
} }
go func() { go func() {
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
log.Fatal(err) log.Crit("error serving HTTP", log.Ctx{"error": err})
panic(err)
} }
}() }()

@ -27,11 +27,11 @@ import (
"github.com/topicai/candy" "github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 8001, "port of the pserver")
index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry") index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
@ -41,13 +41,17 @@ func main() {
checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path")
checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warn, error, crit")
flag.Parse() flag.Parse()
level, err := log.ParseLevel(*logLevel) lvl, err := log.LvlFromString(*logLevel)
candy.Must(err) if err != nil {
panic(err)
}
log.SetLevel(level) log.Root().SetHandler(
log.LvlFilterHandler(lvl, log.CallerStackHandler("%+v", log.StderrHandler)),
)
var idx int var idx int
@ -63,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx) cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil { if err != nil {
if err == pserver.ErrCheckpointNotFound { if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.") log.Info("Could not find the pserver checkpoint.")
} else { } else {
panic(err) panic(err)
} }
@ -71,10 +75,10 @@ func main() {
} }
shutdown := func() { shutdown := func() {
log.Infoln("shutting down gracefully") log.Info("shutting down gracefully")
sErr := e.Shutdown() sErr := e.Shutdown()
if sErr != nil { if sErr != nil {
log.Errorln(sErr) log.Error("error shutting down", log.Ctx{"error": sErr})
} }
} }
@ -95,7 +99,7 @@ func main() {
candy.Must(err) candy.Must(err)
go func() { go func() {
log.Infof("start pserver at port %d", *port) log.Info("starting pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil) err = http.Serve(l, nil)
candy.Must(err) candy.Must(err)
}() }()

16
go/glide.lock generated

@ -1,5 +1,5 @@
hash: 328e7b9b7306b45e7b9879139a9f86698115981f6283032e1312093a6a6ddb04 hash: 51d9e2e46d7fd9173ff11ecada40f7b7728756be18d5e2f032535f66465e6e15
updated: 2017-10-16T08:00:23.484693528Z updated: 2017-10-24T15:04:09.987751592-07:00
imports: imports:
- name: github.com/alecthomas/gometalinter - name: github.com/alecthomas/gometalinter
version: bae2f1293d092fd8167939d5108d1b025eaef9de version: bae2f1293d092fd8167939d5108d1b025eaef9de
@ -99,6 +99,8 @@ imports:
version: d2709f9f1f31ebcda9651b03077758c1f3a0018c version: d2709f9f1f31ebcda9651b03077758c1f3a0018c
- name: github.com/ghodss/yaml - name: github.com/ghodss/yaml
version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7 version: 0ca9ea5df5451ffdf184b4428c902747c2c11cd7
- name: github.com/go-stack/stack
version: 817915b46b97fd7bb80e8ab6b69f01a53ac3eebf
- name: github.com/gogo/protobuf - name: github.com/gogo/protobuf
version: 909568be09de550ed094403c2bf8a261b5bb730a version: 909568be09de550ed094403c2bf8a261b5bb730a
subpackages: subpackages:
@ -120,8 +122,14 @@ imports:
- runtime - runtime
- runtime/internal - runtime/internal
- utilities - utilities
- name: github.com/inconshreveable/log15
version: 0decfc6c20d9ca0ad143b0e89dcaa20f810b4fb3
- name: github.com/jonboulle/clockwork - name: github.com/jonboulle/clockwork
version: 2eee05ed794112d45db504eb05aa693efd2b8b09 version: 2eee05ed794112d45db504eb05aa693efd2b8b09
- name: github.com/mattn/go-colorable
version: 5411d3eea5978e6cdc258b30de592b60df6aba96
- name: github.com/mattn/go-isatty
version: 57fdcb988a5c543893cc61bce354a6e24ab70022
- name: github.com/matttproud/golang_protobuf_extensions - name: github.com/matttproud/golang_protobuf_extensions
version: c12348ce28de40eed0136aa2b644d0ee0650e56c version: c12348ce28de40eed0136aa2b644d0ee0650e56c
subpackages: subpackages:
@ -179,11 +187,12 @@ imports:
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sys - name: golang.org/x/sys
version: 0f826bdd13b500be0f1d4004938ad978fcc6031e version: e48874b42435b4347fc52bdee0424a52abc974d7
repo: https://github.com/golang/sys.git repo: https://github.com/golang/sys.git
vcs: git vcs: git
subpackages: subpackages:
- unix - unix
- windows
- name: golang.org/x/text - name: golang.org/x/text
version: 836efe42bb4aa16aaa17b9c155d8813d336ed720 version: 836efe42bb4aa16aaa17b9c155d8813d336ed720
repo: https://github.com/golang/text.git repo: https://github.com/golang/text.git
@ -222,4 +231,3 @@ testImports:
version: 05e8a0eda380579888eb53c394909df027f06991 version: 05e8a0eda380579888eb53c394909df027f06991
subpackages: subpackages:
- assert - assert

@ -26,3 +26,7 @@ import:
version: v1.1.0 version: v1.1.0
- package: github.com/alecthomas/gometalinter - package: github.com/alecthomas/gometalinter
version: v1.2.1 version: v1.2.1
- package: github.com/inconshreveable/log15
version: v2.13
- package: github.com/go-stack/stack
version: v1.6.0

@ -35,13 +35,19 @@ import (
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client) var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client var curHandle C.paddle_master_client
func init() {
log.Root().SetHandler(
log.LvlFilterHandler(log.LvlWarn, log.CallerStackHandler("%+v", log.StderrHandler)),
)
}
func add(c *master.Client) C.paddle_master_client { func add(c *master.Client) C.paddle_master_client {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
@ -117,7 +123,7 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
} }
err := c.SetDataset(paths) err := c.SetDataset(paths)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error set dataset", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR return C.PADDLE_MASTER_ERROR
} }
@ -167,7 +173,7 @@ func paddle_request_save_model(client C.paddle_master_client, trainerID string,
c := get(client) c := get(client)
need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond) need, err := c.RequestSaveModel(trainerID, time.Duration(blockMS)*time.Millisecond)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error request save model", log.Ctx{"error": err})
return C.PADDLE_MASTER_ERROR return C.PADDLE_MASTER_ERROR
} }

@ -21,7 +21,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// Client is the client of the master server. // Client is the client of the master server.
@ -75,7 +75,7 @@ func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
for { for {
err := f() err := f()
if err != nil { if err != nil {
log.Warningln(err) log.Warn("create etcd client error", log.Ctx{"error": err})
} else { } else {
break break
} }
@ -135,13 +135,13 @@ func (c *Client) getRecords(passID int) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
continue continue
} }
log.Errorf("getTask error: %s", err) log.Error("getTask error.", log.Ctx{"error": err})
} }
for _, chunk := range t.Chunks { for _, chunk := range t.Chunks {
f, e := os.Open(chunk.Path) f, e := os.Open(chunk.Path)
if e != nil { if e != nil {
log.Errorln(e) log.Error("error open chunk", log.Ctx{"error": e})
continue continue
} }
@ -152,12 +152,15 @@ func (c *Client) getRecords(passID int) {
if s.Err() != nil { if s.Err() != nil {
c.ch <- record{nil, s.Err()} c.ch <- record{nil, s.Err()}
log.Errorln(err, chunk.Path) log.Error(
"error scan chunk",
log.Ctx{"error": err, "path": chunk.Path},
)
} }
err = f.Close() err = f.Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error close record file", log.Ctx{"error": err})
} }
} }
@ -166,7 +169,7 @@ func (c *Client) getRecords(passID int) {
// correct, but a reasonable approximation. // correct, but a reasonable approximation.
err = c.taskFinished(t.Meta.ID) err = c.taskFinished(t.Meta.ID)
if err != nil { if err != nil {
log.Errorln(err) log.Error("task finish callback error.", log.Ctx{"error": err})
} }
} }
} }
@ -179,12 +182,12 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
if curMaster == "" { if curMaster == "" {
err := c.conn.Close() err := c.conn.Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("close old master addr error", log.Ctx{"error": err})
} }
} else { } else {
err := c.conn.Connect(curMaster) err := c.conn.Connect(curMaster)
if err != nil { if err != nil {
log.Errorln(err) log.Error("connect to new master addr error", log.Ctx{"error": err})
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order

@ -25,8 +25,6 @@ import (
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
@ -36,10 +34,6 @@ const (
chunkPerTask = 10 chunkPerTask = 10
) )
func init() {
log.SetLevel(log.ErrorLevel)
}
func TestGetFinishTask(t *testing.T) { func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0" const path = "/tmp/master_client_test_0"

@ -20,7 +20,7 @@ import (
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
@ -44,7 +44,7 @@ type EtcdClient struct {
// NewEtcdClient creates a new EtcdClient. // NewEtcdClient creates a new EtcdClient.
func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) {
log.Debugf("Connecting to etcd at %v", endpoints) log.Debug("Connecting to etcd", log.Ctx{"endpoint": endpoints})
cli, err := clientv3.New(clientv3.Config{ cli, err := clientv3.New(clientv3.Config{
Endpoints: endpoints, Endpoints: endpoints,
DialTimeout: dialTimeout, DialTimeout: dialTimeout,
@ -64,12 +64,12 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
// one master running, but split-brain problem may cause // one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management // multiple master servers running), and the cluster management
// software will kill one of them. // software will kill one of them.
log.Infof("Trying to acquire lock at %s.", lockPath) log.Info("Trying to acquire lock.", log.Ctx{"path": lockPath})
err = lock.Lock(context.TODO()) err = lock.Lock(context.TODO())
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Infof("Successfully acquired lock at %s.", lockPath) log.Info("Successfully acquired lock at %s.", log.Ctx{"path": lockPath})
put := clientv3.OpPut(addrPath, addr) put := clientv3.OpPut(addrPath, addr)
resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit()
@ -78,7 +78,8 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Fatal("No longer owns the master lock. Exiting.") log.Crit("No longer owns the master lock. Exiting.")
panic("No longer owns the master lock. Exiting.")
} }
e := &EtcdClient{ e := &EtcdClient{
@ -102,7 +103,7 @@ func (e *EtcdClient) Save(state []byte) error {
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Errorln("No longer owns the lock, trying to lock again") log.Error("No longer owns the lock, trying to lock again")
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
err := e.lock.Lock(ctx) err := e.lock.Lock(ctx)
cancel() cancel()
@ -116,9 +117,10 @@ func (e *EtcdClient) Save(state []byte) error {
// to kill current master server. The current // to kill current master server. The current
// state is not saved, but the trainer's RPC // state is not saved, but the trainer's RPC
// call will fail, so the trainer will retry. // call will fail, so the trainer will retry.
log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err) log.Crit("Could not acquire the lock at %s: %v. Exiting.", log.Ctx{"path": e.lockPath, "error": err})
panic("Could not acquire the lock at %s: %v. Exiting.")
} }
log.Infof("Successfully acquired lock at %s.", e.lockPath) log.Info("Successfully acquired lock at %s.", e.lockPath)
return e.Save(state) return e.Save(state)
} }
@ -136,7 +138,7 @@ func (e *EtcdClient) Load() ([]byte, error) {
} }
if !resp.Succeeded { if !resp.Succeeded {
log.Errorln("No longer owns the lock, trying to lock and load again.") log.Error("No longer owns the lock, trying to lock and load again.")
err = e.lock.Lock(context.Background()) err = e.lock.Lock(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
@ -163,7 +165,7 @@ func (e *EtcdClient) Shutdown() error {
if err == nil { if err == nil {
err = newErr err = newErr
} else { } else {
log.Errorln(newErr) log.Error("shutdown error", log.Ctx{"error": newErr})
} }
} }
@ -192,7 +194,7 @@ func watchKey(c *clientv3.Client, key string, valChan chan<- string) {
for wresp := range rch { for wresp := range rch {
for _, ev := range wresp.Events { for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string // if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value) log.Info("received event.", log.Ctx{"type": ev.Type, "key": ev.Kv.Key, "value": ev.Kv.Value})
valChan <- string(ev.Kv.Value) valChan <- string(ev.Kv.Value)
} }
} }

@ -25,7 +25,7 @@ import (
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
@ -170,11 +170,11 @@ func (s *Service) recover() (bool, error) {
} }
if state == nil { if state == nil {
log.Infoln("No state exists, not recovered.") log.Info("No state exists, not recovered.")
return false, nil return false, nil
} }
log.Infof("Loaded snapshot of size: %d bytes.", len(state)) log.Info("Loaded snapshot.", log.Ctx{"size": len(state)})
gr, err := gzip.NewReader(bytes.NewReader(state)) gr, err := gzip.NewReader(bytes.NewReader(state))
if err != nil { if err != nil {
return false, err return false, err
@ -191,11 +191,11 @@ func (s *Service) recover() (bool, error) {
if err != nil { if err != nil {
// Only close failed, recover actually succeed, so // Only close failed, recover actually succeed, so
// just log error. // just log error.
log.Errorln(err) log.Error("error close recover file.", log.Ctx{"error": err})
} }
s.state = tqs s.state = tqs
log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.") log.Info("Master recovered from snapshot, scheduling pending task timeout check.", s.logCtx())
for _, t := range s.state.Pending { for _, t := range s.state.Pending {
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
} }
@ -224,7 +224,7 @@ func (s *Service) snapshot() error {
} }
state := buf.Bytes() state := buf.Bytes()
log.Infof("Saving snapshot of size: %d bytes.", len(state)) log.Info("Saving snapshot.", log.Ctx{"size bytes": len(state)})
return s.store.Save(state) return s.store.Save(state)
} }
@ -260,7 +260,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
} }
count := index.NumChunks() count := index.NumChunks()
log.Infof("readChunks: file %s has %d chunks", path, count) log.Info("reading chunks.", log.Ctx{"path": path, "num chunks": count})
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
chunk := Chunk{ chunk := Chunk{
Path: path, Path: path,
@ -300,7 +300,7 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
return err return err
} }
close(s.ready) close(s.ready)
@ -320,7 +320,7 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
defer func() { defer func() {
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
} }
}() }()
@ -328,12 +328,12 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
t.NumFailure++ t.NumFailure++
if t.NumFailure > s.failureMax { if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) log.Warn("Task failed to many times, discard.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Failed = append(s.state.Failed, t) s.state.Failed = append(s.state.Failed, t)
return return
} }
log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure) log.Warn("Task failed, re-dispatch.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
s.state.Todo = append(s.state.Todo, t) s.state.Todo = append(s.state.Todo, t)
return return
} }
@ -353,8 +353,8 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
} }
// must be called with lock held. // must be called with lock held.
func (s *Service) logFields() log.Fields { func (s *Service) logCtx() log.Ctx {
return log.Fields{ return log.Ctx{
"todoLen": len(s.state.Todo), "todoLen": len(s.state.Todo),
"pendingLen": len(s.state.Pending), "pendingLen": len(s.state.Pending),
"doneLen": len(s.state.Done), "doneLen": len(s.state.Done),
@ -383,10 +383,10 @@ func (s *Service) GetTask(passID int, task *Task) error {
if len(s.state.Todo) == 0 { if len(s.state.Todo) == 0 {
if len(s.state.Done) == 0 && len(s.state.Pending) == 0 { if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass") log.Warn("All tasks failed, may start next pass", s.logCtx())
return ErrAllTaskFailed return ErrAllTaskFailed
} }
log.WithFields(s.logFields()).Warningln("No more available task.") log.Warn("No more available task.", s.logCtx())
return ErrNoMoreAvailable return ErrNoMoreAvailable
} }
@ -400,8 +400,9 @@ func (s *Service) GetTask(passID int, task *Task) error {
} }
*task = t.Task *task = t.Task
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta) ctx := s.logCtx()
ctx["task meta"] = t.Task.Meta
log.Info("Task dispatched.", ctx)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil return nil
} }
@ -417,7 +418,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t, ok := s.state.Pending[taskID] t, ok := s.state.Pending[taskID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) ctx := s.logCtx()
ctx["task id"] = taskID
log.Warn("Pending task not found.", ctx)
return nil return nil
} }
@ -426,7 +429,9 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.state.Done = append(s.state.Done, t) s.state.Done = append(s.state.Done, t)
delete(s.state.Pending, taskID) delete(s.state.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) ctx := s.logCtx()
ctx["task id"] = taskID
log.Info("Task finished.", ctx)
if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 { if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
// increase master side pass count if all tasks finished // increase master side pass count if all tasks finished
s.state.CurPass++ s.state.CurPass++
@ -434,12 +439,14 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.state.Done = []taskEntry{} s.state.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks // TODO(typhoonzero): deal with failed tasks
s.state.Failed = []taskEntry{} s.state.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass) ctx := s.logCtx()
ctx["new pass"] = s.state.CurPass
log.Warn("all task finished, add new pass data.", ctx)
} }
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
log.Errorln(err) log.Error("snapshot error", log.Ctx{"error": err})
} }
return err return err
} }
@ -455,7 +462,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
t, ok := s.state.Pending[meta.ID] t, ok := s.state.Pending[meta.ID]
if !ok { if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta) log.Warn("TaskFailed:Pending task not found.", log.Ctx{"task": t.Task.Meta})
return nil return nil
} }

@ -45,9 +45,15 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client" "github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// init installs the handler chain on the log15 root logger for this package:
// records are written through log.StderrHandler, wrapped by
// log.CallerStackHandler with the "%+v" format, and filtered at log.LvlWarn.
func init() {
	// Build the chain from the inside out so each stage is named.
	stderrWithStack := log.CallerStackHandler("%+v", log.StderrHandler)
	warnFiltered := log.LvlFilterHandler(log.LvlWarn, stderrWithStack)

	log.Root().SetHandler(warnFiltered)
}
var mu sync.Mutex var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*client.Client) var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client var curHandle C.paddle_pserver_client
@ -164,10 +170,13 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter,
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningf("parameter %s already initialized, treat paddle_init_param as successful.", name) log.Warn(
"parameter already initialized, treat paddle_init_param as successful.",
log.Ctx{"parameter": name},
)
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Error("error init param", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
@ -180,11 +189,11 @@ func paddle_finish_init_params(client C.paddle_pserver_client) C.int {
err := c.FinishInitParams() err := c.FinishInitParams()
if err != nil { if err != nil {
if err.Error() == pserver.AlreadyInitialized { if err.Error() == pserver.AlreadyInitialized {
log.Warningln("parameters already initialized, treat paddle_finish_init_params as successful.") log.Warn("parameters already initialized, treat paddle_finish_init_params as successful.")
return C.PSERVER_OK return C.PSERVER_OK
} }
log.Errorln(err) log.Error("error finish init params", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
@ -205,7 +214,7 @@ func paddle_send_grads(client C.paddle_pserver_client, grads **C.paddle_gradient
c := get(client) c := get(client)
err := c.SendGrads(gs) err := c.SendGrads(gs)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error send grads", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
@ -222,7 +231,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
c := get(client) c := get(client)
ps, err := c.GetParams(ns) ps, err := c.GetParams(ns)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error get params", log.Ctx{"error": err})
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
@ -231,7 +240,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Error(
"pserver returned wrong number of parameters.",
log.Ctx{
"Requested": strings.Join(pn, ", "),
"Returned": strings.Join(ns, ", "),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
@ -241,7 +256,13 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
for i, p := range ps { for i, p := range ps {
pn[i] = p.Name pn[i] = p.Name
} }
log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) log.Error(
"pserver returned wrong parameters, or not in requested order.",
log.Ctx{
"Requested": strings.Join(pn, ", "),
"Returned": strings.Join(ns, ", "),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }
@ -251,13 +272,19 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter,
param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst)))) param := *(**C.paddle_parameter)(unsafe.Pointer((uintptr(unsafe.Pointer(dst)) + uintptr(i)*unsafe.Sizeof(*dst))))
if unsafe.Pointer(param) == nil { if unsafe.Pointer(param) == nil {
log.Errorln("must pre-allocate parameter.") log.Error("must pre-allocate parameter.")
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
if unsafe.Pointer(param.content) != nil { if unsafe.Pointer(param.content) != nil {
if int(param.content_len) != len(p.Content) { if int(param.content_len) != len(p.Content) {
log.Errorf("the pre-allocated content len does not match parameter content len. Pre-allocated len: %d, returned len: %d", param.content_len, len(p.Content)) log.Error(
"the pre-allocated content len does not match parameter content len.",
log.Ctx{
"Pre-allocated len": param.content_len,
"Returned len": len(p.Content),
},
)
return C.PSERVER_ERROR return C.PSERVER_ERROR
} }
} }

@ -22,7 +22,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// TODO(helin): add RPC call retry logic // TODO(helin): add RPC call retry logic
@ -84,7 +84,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
if curServers[i].Addr == "" { if curServers[i].Addr == "" {
err := c.pservers[i].Close() err := c.pservers[i].Close()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error closing connection to pserver", log.Ctx{"error": err})
} }
continue continue
@ -92,7 +92,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
err := c.pservers[i].Connect(curServers[i].Addr) err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil { if err != nil {
log.Errorln(err) log.Error("error connecting to pserver", log.Ctx{"error": err})
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order

@ -30,7 +30,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client" "github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
@ -90,7 +90,7 @@ func initEtcdClient() {
DialTimeout: time.Second * time.Duration(1), DialTimeout: time.Second * time.Duration(1),
}) })
if err != nil { if err != nil {
log.Errorf("err %v", err) log.Error("error init etcd client", log.Ctx{"error": err})
} }
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err = client.Delete(ctx, pserver.PsDesired) _, err = client.Delete(ctx, pserver.PsDesired)

@ -25,7 +25,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
@ -54,26 +54,29 @@ func (e *Etcd) Desired() int {
resp, err := e.client.Get(ctx, pserver.PsDesired) resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) log.Error(
"Get ps dresire number failed! reconnecting...",
log.Ctx{"error": err},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...") log.Info("Waiting for ps desired registered ...")
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("psDesired %d invalid %v", psDesired, err) log.Error("atoi failed", log.Ctx{"error": err})
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("Get psDesired number: %d", psDesired) log.Debug("Got psDesired", log.Ctx{"psDesired": psDesired})
break break
} }
return psDesired return psDesired
@ -88,17 +91,20 @@ func (e *Etcd) List() []Server {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debug("looking for pserver", log.Ctx{"ps key": psKey})
resp, err := e.client.Get(ctx, psKey) resp, err := e.client.Get(ctx, psKey)
cancel() cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Info(
"Get psKey error",
log.Ctx{"ps key": psKey, "error": err},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...") log.Info("Waiting for ps addr registered ...")
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
@ -106,11 +112,17 @@ func (e *Etcd) List() []Server {
psAddr := string(resp.Kvs[0].Value) psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address // TODO(Longfei) check the ps address
if psAddr == "" { if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey) log.Info(
"Value under psKey is empty",
log.Ctx{"psKey": psKey},
)
time.Sleep(e.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("got value (%s) for key: %s", psAddr, psKey) log.Debug(
"got psAddr given psKey",
log.Ctx{"psAddr": psAddr, "psKey": psKey},
)
servers[i].Index = i servers[i].Index = i
servers[i].Addr = psAddr servers[i].Addr = psAddr
} }
@ -130,13 +142,13 @@ func NewEtcd(endpoints string) *Etcd {
DialTimeout: defaultEtcdTimeout, DialTimeout: defaultEtcdTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("Init etcd connection failed: %v", err) log.Error("Init etcd connection failed", log.Ctx{"error": err})
time.Sleep(defaultEtcdTimeout) time.Sleep(defaultEtcdTimeout)
continue continue
} }
break break
} }
log.Infof("Connected to etcd: %s\n", endpoints) log.Info("Connected to etcd endpoint", log.Ctx{"endpoint": endpoints})
client := &Etcd{ client := &Etcd{
client: cli, client: cli,
timeout: defaultEtcdTimeout, timeout: defaultEtcdTimeout,
@ -154,7 +166,7 @@ func (e *Etcd) Select() (bool, error) {
} }
lock := concurrency.NewMutex(sess, initLockPath) lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath) log.Info("Trying to acquire lock", log.Ctx{"lock path": initLockPath})
// Do not use timeout context here, since we don't know how // Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the // long does it take for other trainers to initialize the
// parameters. // parameters.
@ -162,7 +174,7 @@ func (e *Etcd) Select() (bool, error) {
if err != nil { if err != nil {
return false, err return false, err
} }
log.Infof("Successfully acquired lock at %s.", initLockPath) log.Info("Successfully acquired lock", log.Ctx{"lock path": initLockPath})
get := clientv3.OpGet(initDonePath) get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
@ -181,17 +193,17 @@ func (e *Etcd) Select() (bool, error) {
if len(resp.Kvs) == 0 { if len(resp.Kvs) == 0 {
// Key value not set, select current trainer. // Key value not set, select current trainer.
e.lock = lock e.lock = lock
log.Infoln("Trainer selected.") log.Info("Trainer selected.")
return true, nil return true, nil
} }
if string(resp.Kvs[0].Value) == initDoneVal { if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.") log.Info("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout) ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx) err = lock.Unlock(ctx)
cancel() cancel()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error unlocking", log.Ctx{"error": err})
} }
return false, nil return false, nil
} }
@ -221,7 +233,7 @@ func (e *Etcd) Done() error {
err = e.lock.Unlock(ctx) err = e.lock.Unlock(ctx)
cancel() cancel()
if err != nil { if err != nil {
log.Errorln(err) log.Error("error unlocking", log.Ctx{"error": err})
} else { } else {
e.lock = nil e.lock = nil
} }
@ -244,7 +256,7 @@ func (e *Etcd) Close() error {
cErr := e.client.Close() cErr := e.client.Close()
if cErr != nil { if cErr != nil {
if err != nil { if err != nil {
log.Errorln(cErr) log.Error("error closing etcd client", log.Ctx{"error": cErr})
return err return err
} }
return cErr return cErr

@ -24,7 +24,7 @@ import (
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency" "github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
const ( const (
@ -82,19 +82,19 @@ func (e *EtcdClient) Register(port int) (int, error) {
DialTimeout: e.dialTimeout, DialTimeout: e.dialTimeout,
}) })
if err != nil { if err != nil {
log.Errorf("connect to etcd error: %v", err) log.Error("connect to etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
e.client = cli e.client = cli
sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec)) sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
if err != nil { if err != nil {
log.Errorf("create etcd session error: %v", err) log.Error("create etcd session error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
e.sess = sess e.sess = sess
log.Debugf("inited client to %s", e.endpoints) log.Debug("connected to etcd", log.Ctx{"endpoint": e.endpoints})
break break
} }
// init /ps_desired using transaction, for multiple pservers may want to write // init /ps_desired using transaction, for multiple pservers may want to write
@ -104,7 +104,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
_, err := e.initDesiredPservers(ctx, e.numPservers) _, err := e.initDesiredPservers(ctx, e.numPservers)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn("pserver init error", log.Ctx{"error": err, "num pservers": e.numPservers})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
@ -119,14 +119,17 @@ func (e *EtcdClient) Register(port int) (int, error) {
resp, err := e.client.Get(ctx, PsDesired) resp, err := e.client.Get(ctx, PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err) log.Error("get etcd key error", log.Ctx{"key": PsDesired, "error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
if len(resp.Kvs) != 0 { if len(resp.Kvs) != 0 {
e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err) log.Error(
"psDesired atoi error",
log.Ctx{"error": err, "value": string(resp.Kvs[0].Value)},
)
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
// NOTE: wait util ps_desired value change // NOTE: wait util ps_desired value change
continue continue
@ -143,7 +146,7 @@ func (e *EtcdClient) Register(port int) (int, error) {
pserverIdx, err = e.registerPserverEtcd(ctx, port) pserverIdx, err = e.registerPserverEtcd(ctx, port)
cancel() cancel()
if err != nil { if err != nil {
log.Warn(err) log.Warn("register pserver on etcd error", log.Ctx{"error": err})
time.Sleep(retryTimeout) time.Sleep(retryTimeout)
continue continue
} }
@ -170,16 +173,17 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, er
registered := false registered := false
for i := 0; i < e.desired; i++ { for i := 0; i < e.desired; i++ {
psKey := PsPath + strconv.Itoa(i) psKey := PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey) ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey) log.Debug(
"register pserver got value",
log.Ctx{"value": ps, "key": psKey},
)
if ps == "" { if ps == "" {
// find the first id and write info // find the first id and write info
pserverAddr := e.externalIP + ":" + strconv.Itoa(port) pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease())) c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
log.Debugf("set pserver node %s with value %s", psKey, pserverAddr) log.Debug("register finished", log.Ctx{"key": psKey, "value": pserverAddr})
log.Debug("register finished")
idx = i idx = i
registered = true registered = true
break break
@ -239,7 +243,7 @@ func (e *EtcdClient) Shutdown() error {
newErr := e.client.Close() newErr := e.client.Close()
if newErr != nil { if newErr != nil {
if err != nil { if err != nil {
log.Errorln(newErr) log.Error("shutdown error", log.Ctx{"error": newErr})
} else { } else {
err = newErr err = newErr
} }

@ -25,7 +25,7 @@ import (
"fmt" "fmt"
"unsafe" "unsafe"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
type optimizer struct { type optimizer struct {
@ -56,12 +56,12 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
c := paramWithConfigs.Config c := paramWithConfigs.Config
s := State s := State
paramBufferSize := C.size_t(len(p.Content)) paramBufferSize := C.size_t(len(p.Content))
log.WithFields(log.Fields{ log.Info("New Optimizer Created with config", log.Ctx{
"ElementType": p.ElementType, "ElementType": p.ElementType,
"ParamSize": paramBufferSize, "ParamSize": paramBufferSize,
"ConfigSize": len(c), "ConfigSize": len(c),
"StateSize": len(s), "StateSize": len(s),
}).Info("New Optimizer Created with config:") })
var cbuffer unsafe.Pointer var cbuffer unsafe.Pointer
cbuffer = C.malloc(paramBufferSize) cbuffer = C.malloc(paramBufferSize)
@ -72,21 +72,34 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
} }
o.config = c o.config = c
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), o.opt = C.paddle_create_optimizer(
C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s))) (*C.uchar)(&c[0]),
C.int(len(c)),
C.paddle_element_type(p.ElementType),
cbuffer,
C.int(paramBufferSize),
(*C.char)(cstate),
C.int(len(s)),
)
return o return o
} }
// GetWeights asks the C optimizer for its weight buffer and returns it as a
// byte slice. The slice length is the reported element count multiplied by
// C.sizeof_float.
func (o *optimizer) GetWeights() []byte { func (o *optimizer) GetWeights() []byte {
var buffer unsafe.Pointer var buffer unsafe.Pointer
// We do not own this buffer, so it must not be freed here.
// NOTE(review): the returned slice presumably aliases the C memory rather
// than copying it -- confirm against cArrayToSlice.
bufferLen := C.paddle_optimizer_get_weights(o.opt, &buffer) bufferLen := C.paddle_optimizer_get_weights(o.opt, &buffer)
return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float) return cArrayToSlice(buffer, int(bufferLen)*C.sizeof_float)
} }
// GetStates asks the C optimizer for its serialized state and returns it as a
// Go-allocated byte slice.
func (o *optimizer) GetStates() []byte { func (o *optimizer) GetStates() []byte {
var cbuffer *C.char var cbuffer *C.char
// We own the state buffer returned by the C side, so it has to be freed
// once its content has been copied out.
cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer) cbufferLen := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen)) buf := cArrayToSlice(unsafe.Pointer(cbuffer), int(cbufferLen))
// Copy into a fresh Go slice before releasing the C buffer, so the returned
// bytes do not depend on memory that is freed on the next line.
cpy := make([]byte, len(buf))
copy(cpy, buf)
C.free(unsafe.Pointer(cbuffer))
return cpy
} }
func (o *optimizer) UpdateParameter(g Gradient) error { func (o *optimizer) UpdateParameter(g Gradient) error {

@ -15,8 +15,12 @@
package pserver package pserver
import ( import (
"encoding/binary"
"io/ioutil" "io/ioutil"
"math"
"testing" "testing"
"github.com/stretchr/testify/assert"
) )
func TestOptimizerCreateRelease(t *testing.T) { func TestOptimizerCreateRelease(t *testing.T) {
@ -36,3 +40,39 @@ func TestOptimizerCreateRelease(t *testing.T) {
o := newOptimizer(param, nil) o := newOptimizer(param, nil)
o.Cleanup() o.Cleanup()
} }
// float32Bytes encodes a float32 as the four bytes of its IEEE 754 bit
// pattern (math.Float32bits), laid out in little-endian order.
func float32Bytes(float float32) []byte {
	var encoded [4]byte
	binary.LittleEndian.PutUint32(encoded[:], math.Float32bits(float))
	return encoded[:]
}
func TestOptimizerState(t *testing.T) {
p := Parameter{
Name: "a",
ElementType: Int32,
}
weights := float32Bytes(100)
p.Content = weights
config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
param := ParameterWithConfig{
Param: p,
Config: config,
}
o := newOptimizer(param, nil)
s := o.GetStates()
// clear param content and check if the state is restored.
param.Param.Content = float32Bytes(300)
o1 := newOptimizer(param, s)
s1 := o1.GetStates()
assert.Equal(t, s, s1)
assert.Equal(t, weights, o.GetWeights())
assert.Equal(t, weights, o1.GetWeights())
o.Cleanup()
o1.Cleanup()
}

@ -32,7 +32,7 @@ import (
uuid "github.com/satori/go.uuid" uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus" log "github.com/inconshreveable/log15"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
@ -209,7 +209,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t { for range t {
err := s.checkpoint() err := s.checkpoint()
if err != nil { if err != nil {
log.Errorln(err) log.Error("finish init params error", log.Ctx{"error": err})
} }
} }
}() }()
@ -262,7 +262,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
// traceTime logs, at info level, how much time has passed since start,
// labelled with name.
func traceTime(start time.Time, name string) { func traceTime(start time.Time, name string) {
elapsed := time.Since(start) elapsed := time.Since(start)
log.Infof("%s took %v", name, elapsed) log.Info("time elapsed", log.Ctx{"name": name, "elapsed": elapsed})
} }
// checkpoint saves checkpoint to disk. // checkpoint saves checkpoint to disk.
@ -270,7 +270,7 @@ func traceTime(start time.Time, name string) {
// checkpoint should be only called after the parameters are // checkpoint should be only called after the parameters are
// initialized. // initialized.
func (s *Service) checkpoint() (err error) { func (s *Service) checkpoint() (err error) {
log.Infoln("Begin save checkpoint.") log.Info("Begin save checkpoint.")
defer traceTime(time.Now(), "save checkpoint") defer traceTime(time.Now(), "save checkpoint")
s.mu.Lock() s.mu.Lock()
@ -297,6 +297,13 @@ func (s *Service) checkpoint() (err error) {
return return
} }
if _, err = os.Stat(s.checkpointPath); os.IsNotExist(err) {
err = os.MkdirAll(s.checkpointPath, os.ModePerm)
if err != nil {
return
}
}
id := uuid.NewV4().String() id := uuid.NewV4().String()
p := path.Join(s.checkpointPath, id) p := path.Join(s.checkpointPath, id)
f, err := os.Create(p) f, err := os.Create(p)
@ -308,7 +315,7 @@ func (s *Service) checkpoint() (err error) {
closeErr := f.Close() closeErr := f.Close()
if closeErr != nil { if closeErr != nil {
if err != nil { if err != nil {
log.Errorln(closeErr) log.Error("error close checkpoint file", log.Ctx{"error": closeErr})
} else { } else {
// Set closeErr as return value. // Set closeErr as return value.
err = closeErr err = closeErr
@ -329,7 +336,7 @@ func (s *Service) checkpoint() (err error) {
oldMeta, err := loadMeta(s.client, s.idx) oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound { if err == ErrCheckpointNotFound {
log.Infoln("Do not have existing checkpoint.") log.Info("Do not have existing checkpoint.")
err = nil err = nil
} }
@ -361,7 +368,7 @@ func (s *Service) checkpoint() (err error) {
if rmErr != nil { if rmErr != nil {
// log error, but still treat checkpoint as // log error, but still treat checkpoint as
// successful. // successful.
log.Errorln(rmErr) log.Error("remove old meta file error", log.Ctx{"error": rmErr})
} }
} }

@ -1,4 +1,7 @@
# ddim lib # ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(saver_proto SRCS framework.proto saver.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim)
@ -7,8 +10,8 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor saver_proto framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_test(variable_test SRCS variable_test.cc) cc_test(variable_test SRCS variable_test.cc)
@ -16,7 +19,6 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc) cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope) cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)

@ -115,6 +115,7 @@ message VarDesc {
SELECTED_ROWS = 2; SELECTED_ROWS = 2;
FEED_MINIBATCH = 3; FEED_MINIBATCH = 3;
FETCH_LIST = 4; FETCH_LIST = 4;
STEP_SCOPES = 5;
} }
required string name = 1; required string name = 1;
required VarType type = 2; required VarType type = 2;

@ -13,6 +13,15 @@
limitations under the License. */ limitations under the License. */
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/saver.pb.h"
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <iterator>
#include <glog/logging.h> #include <glog/logging.h>
@ -112,5 +121,140 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
lod_ = new_lod; lod_ = new_lod;
} }
// Serializes this tensor (type, dims, LoD and raw data) into a byte string.
//
// Layout of the returned string, in order:
//   [size_t total_size][size_t desc_size][desc_bytes][data_bytes]
// where total_size is the length of the whole string, desc_bytes is a
// serialized LoDTensorProto and data_bytes are the raw tensor bytes copied
// from the holder (CPU or GPU) into host memory.
//
// NOTE: the two header fields are written as native size_t values, so the
// result is only readable on a platform with the same size_t width and
// endianness.
//
// Fails a PADDLE_ENFORCE check if the tensor has no elements.
std::string LoDTensor::SerializeToString() const {
  // Validate first: nothing has been allocated yet, so a failed check cannot
  // leak a half-filled buffer.
  PADDLE_ENFORCE(this->numel() != 0, "Serialize a empty Tensor!");

  LoDTensorProto desc;

  // set data_type
  const auto& data_type = this->type();
  if (data_type == typeid(bool) || data_type == typeid(int8_t)) {
    // DeserializeFromString allocates BOOL tensors as bool, so a bool tensor
    // must be tagged BOOL here; int8_t keeps its historical BOOL mapping.
    desc.set_data_type(DataType::BOOL);
  } else if (data_type == typeid(int16_t)) {
    desc.set_data_type(DataType::INT16);
  } else if (data_type == typeid(int32_t)) {
    desc.set_data_type(DataType::INT32);
  } else if (data_type == typeid(int64_t)) {
    desc.set_data_type(DataType::INT64);
  } else if (data_type == typeid(float)) {  // NOLINT
    // FIXME(dzh): there is no fp16 in standard c++
    desc.set_data_type(DataType::FP32);
  } else if (data_type == typeid(double)) {  // NOLINT
    desc.set_data_type(DataType::FP64);
  }

  for (int i = 0; i < dims().size(); ++i) {
    desc.add_dims(dims()[i]);
  }

  // set lod information
  desc.set_lod_level(this->NumLevels());
  for (size_t i = 0; i < this->NumLevels(); ++i) {
    LoDInfo* lod = desc.add_levels();
    for (size_t j = 0; j < lod_[i].size(); ++j) {
      lod->add_level(lod_[i][j]);
    }
  }
  desc.set_version(0);

  const std::string desc_bytes = desc.SerializeAsString();

  // FIXME(dzh) : implement fix chunk size buffer.
  const size_t HEADER_SIZE = 2 * sizeof(size_t);
  const size_t DESC_SIZE = desc_bytes.size();
  const size_t DATA_SIZE = holder_->size() - offset_;
  const size_t BUFFER_SIZE = HEADER_SIZE + DESC_SIZE + DATA_SIZE;

  // Assemble the result directly inside the returned string: the string owns
  // the storage, so it is released on every exit path, including exceptions.
  std::string ret(BUFFER_SIZE, '\0');
  char* buffer = &ret[0];

  platform::CPUPlace src_place;
  platform::CPUPlace dst_place;
  memory::Copy(dst_place, buffer, src_place, &BUFFER_SIZE, sizeof(size_t));
  memory::Copy(dst_place, buffer + sizeof(size_t), src_place, &DESC_SIZE,
               sizeof(size_t));
  memory::Copy(dst_place, buffer + HEADER_SIZE, src_place, desc_bytes.data(),
               DESC_SIZE);

  // offset_ is counted in bytes (DATA_SIZE above subtracts it from the
  // holder's byte size), so it is applied to the char pointer unscaled.
  const char* data_begin = static_cast<const char*>(holder_->ptr()) + offset_;
  char* data_dst = buffer + HEADER_SIZE + DESC_SIZE;

  platform::Place place = holder_->place();
  if (platform::is_cpu_place(place)) {
    memory::Copy(dst_place, data_dst, boost::get<platform::CPUPlace>(place),
                 data_begin, DATA_SIZE);
  }
#ifdef PADDLE_WITH_GPU
  if (platform::is_gpu_place(place)) {
    memory::Copy(dst_place, data_dst, boost::get<platform::GPUPlace>(place),
                 data_begin, DATA_SIZE);
  }
#endif

  return ret;
}
// Restores this tensor from a string produced by SerializeToString().
//
// Expected layout of `s`, in order:
//   [size_t total_size][size_t desc_size][desc_bytes][data_bytes]
// where desc_bytes is a serialized LoDTensorProto. The tensor is resized to
// the stored dims, allocated on `dst_place` with the stored element type, its
// LoD is replaced, and the data bytes are copied into the new allocation.
//
// s:         the serialized tensor.
// dst_place: the place (CPU or GPU) on which to allocate the tensor data.
//
// Fails a PADDLE_ENFORCE check if the header is truncated or inconsistent, the
// description cannot be parsed, or the stored data type is not supported.
void LoDTensor::DeserializeFromString(const std::string& s,
                                      const platform::Place& dst_place) {
  const size_t HEADER_SIZE = 2 * sizeof(size_t);
  PADDLE_ENFORCE(s.size() >= HEADER_SIZE,
                 "Serialized LoDTensor is too short to contain a header.");

  size_t BUFFER_SIZE = 0;
  size_t DESC_SIZE = 0;
  platform::CPUPlace src_place;
  memory::Copy(src_place, &BUFFER_SIZE, src_place, s.data(), sizeof(size_t));
  memory::Copy(src_place, &DESC_SIZE, src_place, s.data() + sizeof(size_t),
               sizeof(size_t));

  // Reject sizes that would make the reads below run past the end of `s` or
  // make the DATA_SIZE subtraction wrap around.
  const bool buffer_size_ok =
      BUFFER_SIZE >= HEADER_SIZE && BUFFER_SIZE <= s.size();
  PADDLE_ENFORCE(buffer_size_ok,
                 "Serialized LoDTensor has an invalid total size.");
  PADDLE_ENFORCE(DESC_SIZE <= BUFFER_SIZE - HEADER_SIZE,
                 "Serialized LoDTensor has an invalid description size.");
  const size_t DATA_SIZE = BUFFER_SIZE - HEADER_SIZE - DESC_SIZE;

  const char* desc_begin = s.data() + HEADER_SIZE;
  const char* data_begin = desc_begin + DESC_SIZE;

  // parse LoDTensorDesc
  LoDTensorProto desc;
  const bool parsed =
      desc.ParseFromArray(desc_begin, static_cast<int>(DESC_SIZE));
  PADDLE_ENFORCE(parsed, "Cannot parse the LoDTensor description.");

  std::vector<int64_t> dims;
  std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
  this->Resize(make_ddim(dims));

  // parse data type
  void* ptr = nullptr;
  if (desc.data_type() == DataType::BOOL) {
    ptr = this->mutable_data<bool>(dst_place);
  } else if (desc.data_type() == DataType::INT16) {
    ptr = this->mutable_data<int16_t>(dst_place);
  } else if (desc.data_type() == DataType::INT32) {
    ptr = this->mutable_data<int32_t>(dst_place);
  } else if (desc.data_type() == DataType::INT64) {
    ptr = this->mutable_data<int64_t>(dst_place);
  } else if (desc.data_type() == DataType::FP32) {
    // FIXME(dzh): there is no fp16 in standard c++
    ptr = this->mutable_data<float>(dst_place);
  } else if (desc.data_type() == DataType::FP64) {
    ptr = this->mutable_data<double>(dst_place);
  }
  // An unknown type tag leaves nothing allocated; stop before copying into a
  // null destination.
  PADDLE_ENFORCE(ptr != nullptr,
                 "Unsupported data type in serialized LoDTensor.");

  LoD lod;
  std::vector<size_t> levels;
  for (int i = 0; i < desc.levels().size(); ++i) {
    // Bind by reference: copying the repeated field per level is unnecessary.
    const auto& current_level = desc.levels(i).level();
    levels.assign(current_level.begin(), current_level.end());
    lod.emplace_back(levels);
  }
  this->set_lod(lod);

  // SerializeToString stores every byte of the source holder past its offset,
  // which may be more than the dims require; never write past the storage
  // that was just allocated for this tensor.
  const size_t capacity = holder_->size() - offset_;
  const size_t copy_size = std::min(DATA_SIZE, capacity);

  if (platform::is_cpu_place(dst_place)) {
    memory::Copy(boost::get<platform::CPUPlace>(dst_place), ptr, src_place,
                 data_begin, copy_size);
  }
#ifdef PADDLE_WITH_GPU
  if (platform::is_gpu_place(dst_place)) {
    memory::Copy(boost::get<platform::GPUPlace>(dst_place), ptr, src_place,
                 data_begin, copy_size);
  }
#endif
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save