diff --git a/doc/design/model_format.md b/doc/design/model_format.md
index 754bb398e0..e29129fddf 100644
--- a/doc/design/model_format.md
+++ b/doc/design/model_format.md
@@ -12,27 +12,25 @@ The topology is saved as a plain text in a detailed self-contain protobuf file.
The parameters are saved as a binary file. As we all know, the protobuf message has a limit of [64M size](https://developers.google.com/protocol-buffers/docs/reference/cpp/google.protobuf.io.coded_stream#CodedInputStream.SetTotalBytesLimit.details). We have done a [benchmark experiment](https://github.com/PaddlePaddle/Paddle/pull/4610), which shows that protobuf is not fit for the task.
-As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header. It contains all the necessary information, such as the `dims`, the `name` of the tensor, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores values in a continuous memory buffer. For speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is,
-
-|HeaderLength|ContentLength|**LoDTensorDesc**|**TensorValue**|
+As a result, we design a particular format for tensor serialization. By default, an arbitrary tensor in Paddle is a [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md), and has a description information proto of [LoDTensorDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L99). We save the DescProto as the byte string header. It contains all the necessary information, such as the `dims`, and the `LoD` information in [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/1c0a4c901c9fc881d120249c703b15d1c50dae7d/paddle/framework/lod_tensor.md). A tensor stores values in a continuous memory buffer. For speed we dump the raw memory to disk and save it as the byte string content. So, the binary format of one tensor is,
The table below shows a tensor's byte view in detail. Note that all the signed values are written in the little-endian format.
-```text
-[offset] [type] [description]
-0004 4 bytes integer HeaderLength, the length of LoDTensorDesc
-0008 4 bytes integer ContentLength, the length of LodTensor Buffer
-0009 1 bytes char TensorDesc
-00010 1 bytes char TensorDesc
-...
-00100 1 bytes char TensorValue
-00101 1 bytes char TensorValue
-00102 1 bytes char TensorValue ..
-...
-```
+|field name | type | description |
+| --- | --- | --- |
+| version | uint32_t | Version of saved file. Always 0 now. |
+| tensor desc length | uint32_t | TensorDesc(Protobuf message) length in bytes. |
+| tensor desc | void* | TensorDesc protobuf binary message |
+| tensor data | void* | Tensor's data in binary format. The length of `tensor_data` is decided by `TensorDesc.dims()` and `TensorDesc.data_type()` |
+| lod_level | uint64_t | Level of LoD |
+| length of lod[0] | uint64_t | [Optional] length of lod[0] in bytes. |
+| data of lod[0] | uint64_t* | [Optional] lod[0].data() |
+| ... | ... | ... |
+
+
## Summary
- We introduce a model format.
-- The `ProgramDesc` describe the model **topology**.
+- The model represented by its forward-pass computation procedure is saved in a **ProgramDesc** protobuf message.
- A bunch of specified format binary tensors describe the **parameters**.
diff --git a/doc/design/regularization.md b/doc/design/regularization.md
index 703a9fbdd4..21280ac898 100644
--- a/doc/design/regularization.md
+++ b/doc/design/regularization.md
@@ -1,7 +1,7 @@
# Regularization in PaddlePaddle
## Introduction to Regularization
-A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. Many strategies are used by machine learning practitioners to reduce the test error, possibly at the expense of increased training error. These strategies are collectively known as **regularization**.
+A central problem in machine learning is how to design an algorithm that will perform well not just on the training data, but also on new data. A frequently faced problem is the problem of **overfitting**, where the model does not make reliable predictions on new unseen data. **Regularization** is the process of introducing additional information in order to prevent overfitting. This is usually done by adding extra penalties to the loss function that restricts the parameter spaces that an optimization algorithm can explore.
### Parameter Norm Penalties
Most common regularization approaches in deep learning are based on limiting the capacity of the models by adding a parameter norm penalty to the objective function `J`. This is given as follows:
@@ -18,52 +18,21 @@ The most commonly used norm penalties are the L2 norm penalty and the L1 norm pe
##### L1 Regularization

-A much more detailed mathematical background of reguilarization can be found [here](http://www.deeplearningbook.org/contents/regularization.html).
+A much more detailed mathematical background of regularization can be found [here](http://www.deeplearningbook.org/contents/regularization.html).
+## Regularization Survey
-## How to do Regularization in PaddlePaddle
-
-On surveying existing frameworks like Tensorflow, PyTorch, Caffe, etc, it can be seen that there are 2 common approaches of doing regularization:
-
-1. Making regularization a part of the optimizer using an attribute like `weight_decay` that is used to control the scale of the L2 Penalty. This approach is used in PyTorch as follows:
- ```python
- opt = torch.optim.SGD(params, lr=0.2, weight_decay=0.2)
- ```
- At every optimization step, this code will add the gradient of the L2 Norm of the params to the gradient of the params with respect to the loss function. This can seen in the following code snippet:
- ```python
- if weight_decay != 0:
- d_p.add_(weight_decay, p.data)
- ```
- This is a very restyrictive way of doing regularization and does not give the users enough flexibility.
-
- **Advantages**:
- - It is easy to implement for us.
- - Faster execution of backward. However, it can be done manually by advanced users too.
-
- **Disadvantages**:
- - Not flexible for other regularizations such as L1/L0 regularization.
- - Does not allow for different regularization coefficient for different parameters. For example, in most models, ony the weight matrices are regularized and the bias vectors are unregularized.
- - Tightly coupled optimizer and regularization implementation.
-
-
-2. Adding regularization ops to the graph through Python API. This approach is used by Tensorflow and Caffe. Using this approach, we manually add regularization ops to the graph and then add the regularization loss to the final loss function before sending them to the optimizer.
-
- **Advantages**:
- - Allows for greater flexibility to the users of Paddle. Using this approach, the users can put different regularization to different parameters and also choose parameters that are not a part of regularization.
- - Makes it easy for the users to customize and extend the framework.
-
- **Disadvantages**:
- - Implementation requires comprehensive design and time.
+A detailed survey of regularization in various deep learning frameworks can be found [here](https://github.com/PaddlePaddle/Paddle/wiki/Regularization-Survey).
## Proposal for Regularization in PaddlePaddle
### Low-Level implementation
-In the new design, we propose to create new operations for regularization. For now, we can add 2 ops thgat correspond to the most frequently used regularizations:
+In the new design, we propose to create new operations for regularization. For now, we can add 2 ops that correspond to the most frequently used regularizations:
- L2_regularization_op
- L1_regularization_op
-These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate Cpu and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes. other than L1 and L2 norm penalties.
+These ops can be like any other ops with their own CPU/GPU implementations either using Eigen or separate CPU and GPU kernels. As the initial implementation, we can implement their kernels using Eigen following the abstraction pattern implemented for [Activation Ops](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/accuracy_op.h). This abstraction pattern can make it very easy to implement new regularization schemes other than L1 and L2 norm penalties.
The idea of building ops for regularization is in sync with the refactored Paddle philosophy of using operators to represent any computation unit. The way these ops will be added to the computation graph, will be decided by the [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) in Python API.
@@ -94,7 +63,7 @@ Since we want to create the regularization ops in a lazy manner, the regularizat
#### High-level API
-In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we lso need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers).
+In PaddlePaddle Python API, users will primarily rely on [layer functions](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md#layer-function) to create neural network layers. Hence, we also need to provide regularization functionality in layer functions. The design of these APIs can be postponed for later right now. A good reference for these APIs can be found in [Keras](https://keras.io/regularizers/) and also by looking at Tensorflow in [`tf.contrib.layers`](https://www.tensorflow.org/api_guides/python/contrib.layers).
diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go
index 90f9cf3fcf..1358801c1c 100644
--- a/go/cmd/pserver/pserver.go
+++ b/go/cmd/pserver/pserver.go
@@ -67,7 +67,7 @@ func main() {
cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil {
if err == pserver.ErrCheckpointNotFound {
- log.Info("Could not find the pserver checkpoint.")
+ log.Info("load checkpoint error", "error", err)
} else {
panic(err)
}
@@ -99,7 +99,7 @@ func main() {
candy.Must(err)
go func() {
- log.Info("starting pserver", log.Ctx{"port": *port})
+ log.Info("serving pserver", log.Ctx{"port": *port})
err = http.Serve(l, nil)
candy.Must(err)
}()
diff --git a/go/master/c/client.go b/go/master/c/client.go
index 9a59337108..9a3960d59c 100644
--- a/go/master/c/client.go
+++ b/go/master/c/client.go
@@ -123,7 +123,8 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
}
err := c.SetDataset(paths)
if err != nil {
- log.Error("error set dataset", log.Ctx{"error": err})
+ log.Error("error set dataset",
+ log.Ctx{"error": err, "paths": paths})
return C.PADDLE_MASTER_ERROR
}
diff --git a/go/master/client.go b/go/master/client.go
index 5d657548c9..7bcf869553 100644
--- a/go/master/client.go
+++ b/go/master/client.go
@@ -121,6 +121,7 @@ func (c *Client) StartGetRecords(passID int) {
}
func (c *Client) getRecords(passID int) {
+ i := 0
for {
t, err := c.getTask(passID)
if err != nil {
@@ -130,12 +131,20 @@ func (c *Client) getRecords(passID int) {
c.ch <- record{nil, err}
break
}
- if err.Error() == ErrPassAfter.Error() {
- // wait util last pass finishes
- time.Sleep(time.Second * 3)
- continue
+
+ if i%60 == 0 {
+ log.Debug("getTask of passID error.",
+ log.Ctx{"error": err, "passID": passID})
+ i = 0
}
- log.Error("getTask error.", log.Ctx{"error": err})
+
+ // if err.Error() == ErrPassAfter.Error()
+ // wait util last pass finishes
+ // if other error such as network error
+ // wait to reconnect or task time out
+ time.Sleep(time.Second * 3)
+ i += 3
+ continue
}
for _, chunk := range t.Chunks {
diff --git a/go/master/client_test.go b/go/master/client_test.go
index 79b9cc844d..1963dbfd73 100644
--- a/go/master/client_test.go
+++ b/go/master/client_test.go
@@ -117,6 +117,7 @@ func TestNextRecord(t *testing.T) {
if e != nil {
panic(e)
}
+
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go
index e04c86de0a..6d28cad25a 100644
--- a/go/pserver/optimizer.go
+++ b/go/pserver/optimizer.go
@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0])
}
+ var cptr (*C.uchar)
+ if len(c) > 0 {
+ cptr = (*C.uchar)(&c[0])
+ } else {
+ log.Error("empty config", "param name", paramWithConfigs.Param.Name)
+ }
o.config = c
o.opt = C.paddle_create_optimizer(
- (*C.uchar)(&c[0]),
+ cptr,
C.int(len(c)),
C.paddle_element_type(p.ElementType),
cbuffer,
diff --git a/go/pserver/service.go b/go/pserver/service.go
index 6f66faaf27..f703d99a29 100644
--- a/go/pserver/service.go
+++ b/go/pserver/service.go
@@ -17,12 +17,11 @@ package pserver
import (
"bufio"
"bytes"
- "crypto/md5"
"encoding/gob"
- "encoding/hex"
"encoding/json"
"errors"
"fmt"
+ "hash/crc32"
"io/ioutil"
"os"
"path"
@@ -40,7 +39,7 @@ type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
-var ErrCheckpointNotFound = errors.New("checkpoint not found")
+var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
// RPC error message.
const (
@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type checkpointMeta struct {
UUID string `json:"uuid"`
Path string `json:"path"`
- MD5 string `json:"md5"`
+ CRC32 uint32 `json:"crc32"`
Timestamp int64 `json:"timestamp"`
}
@@ -92,7 +91,7 @@ type Service struct {
idx int
checkpointInterval time.Duration
checkpointPath string
- client *EtcdClient
+ client KVStore
mu sync.Mutex
optMap map[string]*optimizer
@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State []byte
}
-func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
+type KVStore interface {
+ GetKey(key string, timeout time.Duration) ([]byte, error)
+ PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
+}
+
+func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
if err != nil {
return
@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}
// LoadCheckpoint loads checkpoint from file.
-func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
+func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
log.Info("Loading checkpoint", "pserver index", idx)
defer traceTime(time.Now(), "load checkpoint")
@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return nil, err
}
- // TODO(helin): change MD5 to CRC since CRC is better for file
- // checksum in our use case (emphasize speed over security).
- h := md5.New()
- md5 := hex.EncodeToString(h.Sum(content))
- if md5 != cpMeta.MD5 {
+ crc32 := crc32.ChecksumIEEE(content)
+ if crc32 != cpMeta.CRC32 {
return nil, errors.New(WrongChecksum)
}
@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if err = dec.Decode(&cp); err != nil {
return nil, err
}
+
return cp, nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
-func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
+func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
s := &Service{
idx: idx,
checkpointInterval: interval,
@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
s.optMap[p.Param.Name] = newOptimizer(p, item.State)
}
+ close(s.initialized)
}
return s, nil
}
@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for range t {
err := s.checkpoint()
if err != nil {
- log.Error("finish init params error", log.Ctx{"error": err})
+ log.Error("checkpoint error", log.Ctx{"error": err})
}
}
}()
@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter.Name = name
parameter.ElementType = opt.elementType
parameter.Content = opt.GetWeights()
+
log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
return nil
}
@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
- log.Info("Do not have existing checkpoint.")
+ log.Info("old meta not found, skip removing old meta")
err = nil
+ } else if err == nil {
+ log.Info("removing old meta")
+ if oldMeta.Path != "" {
+ rmErr := os.Remove(oldMeta.Path)
+ if rmErr != nil {
+ // log error, but still treat checkpoint as
+ // successful.
+ log.Error("remove old meta file error", log.Ctx{"error": rmErr})
+ }
+ }
}
if err != nil {
return
}
- h := md5.New()
- md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
+ crc32 := crc32.ChecksumIEEE(buf.Bytes())
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
- MD5: md5,
+ CRC32: crc32,
Path: p,
}
@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
}
- if oldMeta.Path != "" {
- rmErr := os.Remove(oldMeta.Path)
- if rmErr != nil {
- // log error, but still treat checkpoint as
- // successful.
- log.Error("remove old meta file error", log.Ctx{"error": rmErr})
- }
- }
-
return
}
diff --git a/go/pserver/service_internal_test.go b/go/pserver/service_internal_test.go
new file mode 100644
index 0000000000..36eca5112b
--- /dev/null
+++ b/go/pserver/service_internal_test.go
@@ -0,0 +1,86 @@
+package pserver
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+)
+
+const testDir = "./test_data"
+
+type myKV struct {
+ m map[string][]byte
+}
+
+func (m *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
+ if m.m == nil {
+ m.m = make(map[string][]byte)
+ }
+ return m.m[key], nil
+}
+
+func (m *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
+ if m.m == nil {
+ m.m = make(map[string][]byte)
+ }
+ m.m[key] = value
+ return nil
+}
+
+func TestCheckpoint(t *testing.T) {
+ kv := &myKV{}
+ s, err := NewService(0, time.Hour, testDir, kv, nil)
+ assert.Nil(t, err)
+ err = s.checkpoint()
+ assert.Nil(t, err)
+ _, err = LoadCheckpoint(kv, 0)
+ assert.Nil(t, err)
+}
+
+func float32ToByte(f float32) []byte {
+ var buf bytes.Buffer
+ err := binary.Write(&buf, binary.LittleEndian, f)
+ if err != nil {
+ fmt.Println("binary.Write failed:", err)
+ }
+ return buf.Bytes()
+}
+
+func TestCheckpointWithData(t *testing.T) {
+ kv := &myKV{}
+ s, err := NewService(0, time.Hour, testDir, kv, nil)
+ assert.Nil(t, err)
+
+ var content []byte
+ for i := 0; i < 50000; i++ {
+ content = append(content, float32ToByte(float32(i))...)
+ }
+
+ p1 := Parameter{Name: "p1", ElementType: 1, Content: content}
+ err = s.InitParam(ParameterWithConfig{Param: p1}, nil)
+ assert.Nil(t, err)
+
+ err = s.FinishInitParams(0, nil)
+ assert.Nil(t, err)
+
+ var p2 Parameter
+ err = s.GetParam(p1.Name, &p2)
+ assert.Nil(t, err)
+ assert.Equal(t, p1, p2)
+
+ err = s.checkpoint()
+ assert.Nil(t, err)
+ cp, err := LoadCheckpoint(kv, 0)
+ assert.Nil(t, err)
+ s1, err := NewService(0, time.Hour, testDir, kv, cp)
+ assert.Nil(t, err)
+
+ var p3 Parameter
+ err = s1.GetParam(p1.Name, &p3)
+ assert.Nil(t, err)
+ assert.Equal(t, p1, p3)
+}
diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go
index be648cd1e8..b6f4566eb7 100644
--- a/go/pserver/service_test.go
+++ b/go/pserver/service_test.go
@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Wait()
}
-
-func TestCheckpointSpeed(t *testing.T) {
- //TODO(zhihong): test speed
-}
diff --git a/paddle/capi/gradient_machine.cpp b/paddle/capi/gradient_machine.cpp
index 629449bbd4..482b51e8a8 100644
--- a/paddle/capi/gradient_machine.cpp
+++ b/paddle/capi/gradient_machine.cpp
@@ -64,12 +64,18 @@ paddle_error paddle_gradient_machine_create_for_inference_with_parameters(
modelConfigProtobuf.resize(modelConfigSize);
is.read(&modelConfigProtobuf[0], modelConfigSize);
paddle::TrainerConfig config;
+ paddle::ModelConfig modelConfig;
if (!config.ParseFromString(modelConfigProtobuf) || !config.IsInitialized()) {
- return kPD_PROTOBUF_ERROR;
+ if (!modelConfig.ParseFromString(modelConfigProtobuf) ||
+ !modelConfig.IsInitialized()) {
+ return kPD_PROTOBUF_ERROR;
+ }
+ } else {
+ modelConfig = config.model_config();
}
auto ptr = new paddle::capi::CGradientMachine();
ptr->machine.reset(paddle::GradientMachine::create(
- config.model_config(), CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
+ modelConfig, CREATE_MODE_TESTING, {paddle::PARAMETER_VALUE}));
std::vector& parameters = ptr->machine->getParameters();
for (auto& para : parameters) {
para->load(is);
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index 85374a476d..0d1617424e 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -1,6 +1,5 @@
# ddim lib
proto_library(framework_proto SRCS framework.proto)
-proto_library(saver_proto SRCS framework.proto saver.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
@@ -10,7 +9,7 @@ cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
-cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor saver_proto framework_proto)
+cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
@@ -27,7 +26,7 @@ cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator)
+cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info operator glog)
cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
@@ -43,7 +42,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
cc_library(backward SRCS backward.cc DEPS net_op)
-cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
+cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context fill_constant_op)
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward glog)
diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc
index 1ae7fb60f0..150c152367 100644
--- a/paddle/framework/backward.cc
+++ b/paddle/framework/backward.cc
@@ -315,6 +315,7 @@ static void CreateGradVarInBlock(
return false; /* not break */
});
if (need_infer_shape) {
+ ops[op_index]->InferVarType(block_desc);
ops[op_index]->InferShape(*block_desc);
}
}
@@ -452,11 +453,16 @@ ParamGradInfoMap AppendBackward(
std::transform(target_shape_desc.begin(), target_shape_desc.end(),
std::back_inserter(target_shape),
[](int64_t dim) { return static_cast(dim); });
+ VLOG(3) << "backward from loss=" << target.Name()
+ << " data_type=" << target.GetDataType();
std::unique_ptr fill_one_op(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", target_shape},
{"value", static_cast(1.0)},
- {"data_type", framework::DataType::FP32}}));
+ {"data_type", target.GetDataType()}}));
+ // infer var type of fill_one_op
+ fill_one_op->InferVarType(root_block);
+
root_block->AppendAllocatedOp(std::move(fill_one_op));
size_t forward_op_num = root_block->OpSize();
size_t forward_block_num = program_desc.Size();
@@ -475,8 +481,7 @@ ParamGradInfoMap AppendBackward(
std::unordered_map retv;
auto var = root_block->Var(fill_one_op_out);
- // FIXME(qiao) infer the data type
- var->SetDataType(framework::DataType::FP32);
+ var->SetDataType(target.GetDataType());
var->SetShape(target.Shape());
auto& target_grad = retv[target.Name()];
target_grad.name_ = fill_one_op_out;
diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc
index 10301f7e39..421f132194 100644
--- a/paddle/framework/backward_test.cc
+++ b/paddle/framework/backward_test.cc
@@ -21,6 +21,8 @@
#include "paddle/framework/var_desc.h"
#include "paddle/operators/net_op.h"
+USE_OP(fill_constant);
+
namespace paddle {
namespace framework {
diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc
index 251e340e6d..b73a20cc89 100644
--- a/paddle/framework/block_desc.cc
+++ b/paddle/framework/block_desc.cc
@@ -120,6 +120,17 @@ BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}
+
+BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
+ : prog_(prog), desc_(desc), need_update_(false) {
+ for (const VarDesc &var_desc : desc_->vars()) {
+ vars_[var_desc.name()].reset(new VarDescBind(var_desc));
+ }
+ for (const OpDesc &op_desc : desc_->ops()) {
+ ops_.emplace_back(new OpDescBind(op_desc, prog));
+ }
+}
+
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog)
: prog_(prog), desc_(desc) {
diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h
index c685050850..72f77a88a2 100644
--- a/paddle/framework/block_desc.h
+++ b/paddle/framework/block_desc.h
@@ -36,8 +36,7 @@ class ProgramDescBind;
class BlockDescBind {
public:
- BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
- : prog_(prog), desc_(desc), need_update_(false) {}
+ BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);
BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog);
diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h
index c25a62c2b1..bafb4fbd48 100644
--- a/paddle/framework/data_type.h
+++ b/paddle/framework/data_type.h
@@ -15,6 +15,7 @@
#pragma once
#include
#include "paddle/framework/framework.pb.h"
+#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index a335786753..239ae5e123 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -195,6 +195,14 @@ std::vector vectorize(const DDim& ddim) {
return result;
}
+// NOTE: framework::vectorize converts to type int64_t
+// which does not fit cudnn inputs.
+std::vector vectorize2int(const DDim& ddim) {
+ std::vector temp = vectorize(ddim);
+ std::vector result(temp.begin(), temp.end());
+ return result;
+}
+
struct ProductVisitor : public boost::static_visitor {
template
int64_t operator()(const Dim& dim) {
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index 4a871bb0a9..2a5e2d2b69 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -93,6 +93,7 @@ int64_t get(const DDim& dim, int idx);
void set(DDim& dim, int idx, int val);
std::vector vectorize(const DDim& ddim);
+std::vector vectorize2int(const DDim& ddim);
int64_t product(const DDim& ddim);
diff --git a/paddle/framework/details/op_registry.h b/paddle/framework/details/op_registry.h
index 357ad21f39..b731840ef2 100644
--- a/paddle/framework/details/op_registry.h
+++ b/paddle/framework/details/op_registry.h
@@ -28,7 +28,8 @@ enum OpInfoFillType {
kOperator = 0,
kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2,
- kVarTypeInference = 3
+ kVarTypeInference = 3,
+ kShapeInference = 4
};
template
@@ -42,7 +43,10 @@ struct OpInfoFillTypeID {
? kGradOpDescMaker
: (std::is_base_of::value
? kVarTypeInference
- : static_cast(-1))));
+ : (std::is_base_of::value
+ ? kShapeInference
+ : static_cast(
+ -1)))));
}
};
@@ -121,6 +125,16 @@ struct OpInfoFiller {
}
};
+template
+struct OpInfoFiller {
+ void operator()(const char* op_type, OpInfo* info) const {
+ info->infer_shape_ = [](InferShapeContext* ctx) {
+ T inference;
+ inference(ctx);
+ };
+ }
+};
+
} // namespace details
} // namespace framework
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index 1f1e4edda8..3e9d8b3084 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -20,6 +20,7 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"
@@ -56,6 +57,22 @@ Executor::~Executor() {
}
}
+static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
+ if (var_type == VarDesc::LOD_TENSOR) {
+ var->GetMutable();
+ } else if (var_type == VarDesc::SELECTED_ROWS) {
+ var->GetMutable();
+ } else if (var_type == VarDesc::FEED_MINIBATCH) {
+ var->GetMutable();
+ } else if (var_type == VarDesc::FETCH_LIST) {
+ var->GetMutable();
+ } else {
+ PADDLE_THROW(
+ "Variable type must be "
+ "LoDTensor/SelectedRows/FEED_MINIBATCH/FETCH_LIST.");
+ }
+}
+
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
@@ -69,10 +86,12 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
for (auto& var : block.vars()) {
if (var.persistable()) {
auto* ptr = scope->Var(var.name());
+ CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope.Var(var.name());
+ CreateTensor(ptr, var.type());
VLOG(3) << "Create Variable " << var.name()
<< " locally, which pointer is " << ptr;
}
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 731235cd98..584308a538 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -13,7 +13,6 @@
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
-#include "paddle/framework/saver.pb.h"
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
@@ -136,141 +135,5 @@ void LoDTensor::ShrinkInLevel(size_t level, size_t elem_begin,
PADDLE_ENFORCE_LT(begin, end, "Cannot shrink, the result tensor is empty.");
ShareDataWith(Slice(begin, end));
}
-
-std::string LoDTensor::SerializeToString() const {
- LoDTensorProto desc;
-
- // set data_type
- if (this->type() == typeid(int8_t)) desc.set_data_type(DataType::BOOL);
- if (this->type() == typeid(int16_t)) desc.set_data_type(DataType::INT16);
- if (this->type() == typeid(int32_t)) desc.set_data_type(DataType::INT32);
- if (this->type() == typeid(int64_t)) desc.set_data_type(DataType::INT64);
- // FIXME(dzh): there is no fp16 in standard c++
-
- if (this->type() == typeid(float)) // NOLINT
- desc.set_data_type(DataType::FP32);
- if (this->type() == typeid(double)) // NOLINT
- desc.set_data_type(DataType::FP64);
-
- for (int i = 0; i < dims().size(); ++i) {
- desc.add_dims(dims()[i]);
- }
-
- // set lod information
- desc.set_lod_level(this->NumLevels());
- for (size_t i = 0; i < this->NumLevels(); ++i) {
- LoDInfo* lod = desc.add_levels();
- for (size_t j = 0; j < lod_[i].size(); ++j) {
- lod->add_level(lod_[i][j]);
- }
- }
-
- desc.set_version(0);
-
- std::string desc_bytes = desc.SerializeAsString();
-
- // FIXME(dzh) : implement fix chunk size buffer.
- size_t DESC_SIZE = desc_bytes.size();
- size_t DATA_SIZE = holder_->size() - offset_;
-
- const size_t BUFFER_SIZE = DESC_SIZE + DATA_SIZE + 2 * sizeof(size_t);
- char* buffer =
- static_cast(memory::Alloc(platform::CPUPlace(), BUFFER_SIZE));
-
- // format: desc_size data_size, desc_bytes, data_bytes.
- platform::CPUPlace src_place;
- platform::CPUPlace dst_place;
-
- memory::Copy(dst_place, buffer, src_place, &BUFFER_SIZE, sizeof(size_t));
- memory::Copy(dst_place, buffer + sizeof(size_t), src_place, &DESC_SIZE,
- sizeof(size_t));
- memory::Copy(dst_place, buffer + sizeof(size_t) * 2, src_place,
- desc_bytes.c_str(), desc_bytes.size());
-
- PADDLE_ENFORCE(this->numel() != 0, "Serialize a empty Tensor!");
-
- platform::Place place = holder_->place();
- int element_width = holder_->size() / this->numel();
-
- if (platform::is_cpu_place(place)) {
- memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(),
- boost::get(place),
- static_cast(holder_->ptr()) + offset_ / element_width,
- DATA_SIZE);
- }
-#ifdef PADDLE_WITH_GPU
- if (platform::is_gpu_place(place)) {
- memory::Copy(dst_place, buffer + sizeof(size_t) * 2 + desc_bytes.size(),
- boost::get(place),
- static_cast(holder_->ptr()) + offset_ / element_width,
- DATA_SIZE);
- }
-#endif
-
- std::string ret(buffer, BUFFER_SIZE);
- memory::Free(platform::CPUPlace(), buffer);
- return ret;
-}
-
-void LoDTensor::DeserializeFromString(const std::string& s,
- const platform::Place& dst_place) {
- size_t DESC_SIZE, BUFFER_SIZE;
- platform::CPUPlace src_place;
-
- memory::Copy(src_place, &BUFFER_SIZE, src_place, s.c_str(), sizeof(size_t));
- memory::Copy(src_place, &DESC_SIZE, src_place, s.c_str() + sizeof(size_t),
- sizeof(size_t));
-
- const size_t DATA_SIZE = BUFFER_SIZE - DESC_SIZE - sizeof(size_t) * 2;
-
- // parse LoDTensorDesc
- LoDTensorProto desc;
- desc.ParseFromArray(s.c_str() + sizeof(size_t) * 2, DESC_SIZE);
-
- std::vector dims;
- std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
- this->Resize(make_ddim(dims));
-
- // parse data type
- void* ptr = nullptr;
- if (desc.data_type() == DataType::BOOL)
- ptr = this->mutable_data(dst_place);
- if (desc.data_type() == DataType::INT16)
- ptr = this->mutable_data(dst_place);
- if (desc.data_type() == DataType::INT32)
- ptr = this->mutable_data(dst_place);
- if (desc.data_type() == DataType::INT64)
- ptr = this->mutable_data(dst_place);
- // FIXME(dzh): there is no fp16 in standard c++
-
- if (desc.data_type() == DataType::FP32)
- ptr = this->mutable_data(dst_place);
- if (desc.data_type() == DataType::FP64)
- ptr = this->mutable_data(dst_place);
-
- LoD lod;
- std::vector levels;
- for (int i = 0; i < desc.levels().size(); ++i) {
- auto current_level = desc.levels()[i].level();
- std::copy(current_level.begin(), current_level.end(),
- std::back_inserter(levels));
- lod.emplace_back(levels);
- levels.clear();
- }
-
- this->set_lod(lod);
-
- if (platform::is_cpu_place(dst_place)) {
- memory::Copy(boost::get(dst_place), ptr, src_place,
- s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE);
- }
-#ifdef PADDLE_WITH_GPU
- if (platform::is_gpu_place(dst_place)) {
- memory::Copy(boost::get(dst_place), ptr, src_place,
- s.c_str() + sizeof(size_t) * 2 + DESC_SIZE, DATA_SIZE);
- }
-#endif
-}
-
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 735d85f750..f4fe4cdac6 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -85,7 +85,9 @@ class LoDTensor : public Tensor {
void set_lod(const LoD& lod) { lod_ = lod; }
- LoD lod() const { return lod_; }
+ const LoD& lod() const { return lod_; }
+
+ LoD* mutable_lod() { return &lod_; }
/*
* Get the start offset and end offset of an element from LoD.
@@ -139,27 +141,6 @@ class LoDTensor : public Tensor {
*/
void ShrinkInLevel(size_t level, size_t elem_begin, size_t elem_end);
- /**
- * @brief Serialize tensor to char bytes.
- * Please check model_format.md for the format detail.
- * NOTE: GPUTensor will copy data to cpu implicitly.
- * @return return string
- */
-
- // FIXME(dzh) : Currently, this interface should only be used in
- // save/restore model and checkpoint. ParameterServer do not use shape
- // information to do the optimization, as a result, when we serialize
- // parameter/gradient to string, we should serialize the tensor
- // to string in the ps trainer instead of LoDTensor.
- std::string SerializeToString() const;
-
- /**
- * @brief Deserialize char bytes to tensor.
- * @return return string
- */
- void DeserializeFromString(const std::string& s,
- const platform::Place& dst_place);
-
private:
LoD lod_;
};
diff --git a/paddle/framework/lod_tensor_test.cc b/paddle/framework/lod_tensor_test.cc
index f309376c8b..aa2f6c993d 100644
--- a/paddle/framework/lod_tensor_test.cc
+++ b/paddle/framework/lod_tensor_test.cc
@@ -144,21 +144,5 @@ TEST(LodExpand, test) {
}
}
-TEST_F(LoDTensorTester, SerializeDeserialize) {
- LoDTensor new_lod_tensor = lod_tensor_;
- float* src_ptr = lod_tensor_.data();
- std::string s = lod_tensor_.SerializeToString();
- LoDTensor dst;
- dst.DeserializeFromString(s, platform::CPUPlace());
- float* dst_ptr = dst.data();
- for (int i = 0; i < kLodTensorSize; ++i) {
- EXPECT_EQ(dst_ptr[i], src_ptr[i]);
- }
-
- ASSERT_EQ(dst.NumElements(0), 2UL);
- ASSERT_EQ(dst.NumElements(1), 3UL);
- ASSERT_EQ(dst.NumElements(2), 8UL);
-}
-
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu
index 11659be02a..c79c4d0c72 100644
--- a/paddle/framework/lod_tensor_test.cu
+++ b/paddle/framework/lod_tensor_test.cu
@@ -47,31 +47,4 @@ TEST(LoDTensor, LoDInGPU) {
for (size_t i = 0; i < src_lod[0].size(); ++i) {
CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
}
-}
-
-TEST(LoDTensor, SerializeDeserialize) {
- paddle::framework::LoDTensor lod_tensor;
- paddle::platform::GPUPlace place(0);
-
- paddle::framework::LoD src_lod;
- src_lod.push_back(std::vector{0, 2, 4, 6, 8, 10, 12, 14});
-
- lod_tensor.Resize({14, 16});
- lod_tensor.mutable_data(place);
-
- lod_tensor.set_lod(src_lod);
- CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
- CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
-
- test<<<1, 8>>>(src_lod[0].data(), src_lod[0].size());
- cudaDeviceSynchronize();
-
- std::string s = lod_tensor.SerializeToString();
- paddle::framework::LoDTensor dst;
- dst.DeserializeFromString(s, place);
- paddle::framework::LoD dst_lod = dst.lod();
-
- for (size_t i = 0; i < dst_lod[0].size(); ++i) {
- CHECK_EQ(src_lod[0].data()[i], dst_lod[0].data()[i] * 2);
- }
-}
+}
\ No newline at end of file
diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc
index 18fabe481d..133869e7b5 100644
--- a/paddle/framework/op_desc.cc
+++ b/paddle/framework/op_desc.cc
@@ -14,9 +14,13 @@ limitations under the License. */
#include "paddle/framework/op_desc.h"
#include
+#include
#include
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
+#include "paddle/framework/program_desc.h"
+
+#include "glog/logging.h"
namespace paddle {
namespace framework {
@@ -24,16 +28,47 @@ namespace framework {
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs) {
- op_desc_.set_type(type);
+ desc_.set_type(type);
inputs_ = inputs;
outputs_ = outputs;
attrs_ = attrs;
need_update_ = true;
}
+OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
+ : desc_(desc), need_update_(false) {
+ // restore inputs_
+ int input_size = desc_.inputs_size();
+ for (int i = 0; i < input_size; ++i) {
+ const OpDesc::Var &var = desc_.inputs(i);
+ std::vector &args = inputs_[var.parameter()];
+ int argu_size = var.arguments_size();
+ args.reserve(argu_size);
+ for (int j = 0; j < argu_size; ++j) {
+ args.push_back(var.arguments(j));
+ }
+ }
+ // restore outputs_
+ int output_size = desc_.outputs_size();
+ for (int i = 0; i < output_size; ++i) {
+ const OpDesc::Var &var = desc_.outputs(i);
+ std::vector &args = outputs_[var.parameter()];
+ int argu_size = var.arguments_size();
+ args.reserve(argu_size);
+ for (int j = 0; j < argu_size; ++j) {
+ args.push_back(var.arguments(j));
+ }
+ }
+ // restore attrs_
+ for (const OpDesc::Attr &attr : desc_.attrs()) {
+ std::string attr_name = attr.name();
+ attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
+ }
+}
+
OpDesc *OpDescBind::Proto() {
Flush();
- return &op_desc_;
+ return &desc_;
}
const std::vector &OpDescBind::Input(
@@ -167,23 +202,23 @@ struct SetAttrDescVisitor : public boost::static_visitor {
void OpDescBind::Flush() {
if (need_update_) {
- this->op_desc_.mutable_inputs()->Clear();
+ this->desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) {
- auto *input = op_desc_.add_inputs();
+ auto *input = desc_.add_inputs();
input->set_parameter(ipt.first);
VectorToRepeated(ipt.second, input->mutable_arguments());
}
- this->op_desc_.mutable_outputs()->Clear();
+ this->desc_.mutable_outputs()->Clear();
for (auto &opt : outputs_) {
- auto *output = op_desc_.add_outputs();
+ auto *output = desc_.add_outputs();
output->set_parameter(opt.first);
VectorToRepeated(opt.second, output->mutable_arguments());
}
- this->op_desc_.mutable_attrs()->Clear();
+ this->desc_.mutable_attrs()->Clear();
for (auto &attr : attrs_) {
- auto *attr_desc = op_desc_.add_attrs();
+ auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast(attr.second.which() - 1));
@@ -195,26 +230,26 @@ void OpDescBind::Flush() {
}
}
-using InferShapeFuncMap =
- std::unordered_map>;
-
-static InferShapeFuncMap &InferShapeFuncs() {
- static InferShapeFuncMap *g_map = nullptr;
- if (g_map == nullptr) {
- g_map = new InferShapeFuncMap();
- auto &info_map = OpInfoMap::Instance();
- // all registered kernels
- for (auto &pair : OperatorWithKernel::AllOpKernels()) {
- auto &info = info_map.Get(pair.first);
- // use empty type here to avoid runtime checks.
+static std::once_flag init_infer_shape_funcs;
+
+static void InitInferShapeFuncs() {
+ std::call_once(init_infer_shape_funcs, [] {
+ auto &map = OpInfoMap::Instance();
+ auto &info_map = *map.mutable_map();
+
+ for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
+ auto op_type = kern_pair.first;
+ auto &op_info = info_map.at(op_type);
auto op =
- static_cast(info.Creator()("", {}, {}, {}));
- g_map->insert(
- {pair.first, [op](InferShapeContext *ctx) { op->InferShape(ctx); }});
+ static_cast(op_info.Creator()("", {}, {}, {}));
+ if (op_info.infer_shape_) { // infer_shape has been registered.
+ continue;
+ }
+ op_info.infer_shape_ = [op](InferShapeContext *ctx) {
+ op->InferShape(ctx);
+ };
}
- }
- return *g_map;
+ });
}
void OpDescBind::CheckAttrs() {
@@ -230,13 +265,13 @@ void OpDescBind::CheckAttrs() {
}
void OpDescBind::InferShape(const BlockDescBind &block) const {
- auto &funcs = InferShapeFuncs();
- auto it = funcs.find(this->Type());
- if (it == funcs.end()) {
- PADDLE_THROW("Operator %s has not been registered", this->Type());
- }
+ VLOG(3) << "CompileTime infer shape on " << Type();
+ InitInferShapeFuncs();
+ auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
+ PADDLE_ENFORCE(static_cast(infer_shape),
+ "%s's infer_shape has not been registered", this->Type());
CompileTimeInferShapeContext ctx(*this, block);
- it->second(&ctx);
+ infer_shape(&ctx);
}
void OpDescBind::InferVarType(BlockDescBind *block) const {
diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h
index 313bf538ac..9b8fe17d6e 100644
--- a/paddle/framework/op_desc.h
+++ b/paddle/framework/op_desc.h
@@ -24,6 +24,7 @@ namespace paddle {
namespace framework {
class BlockDescBind;
+class ProgramDescBind;
class OpDescBind {
public:
@@ -32,11 +33,13 @@ class OpDescBind {
OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);
+ OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
+
OpDesc *Proto();
- std::string Type() const { return op_desc_.type(); }
+ std::string Type() const { return desc_.type(); }
- void SetType(const std::string &type) { op_desc_.set_type(type); }
+ void SetType(const std::string &type) { desc_.set_type(type); }
const std::vector &Input(const std::string &name) const;
@@ -117,7 +120,7 @@ class OpDescBind {
return ret_val;
}
- OpDesc op_desc_;
+ OpDesc desc_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h
index 59a64d7137..d3b1a3b5fa 100644
--- a/paddle/framework/op_info.h
+++ b/paddle/framework/op_info.h
@@ -25,12 +25,19 @@
namespace paddle {
namespace framework {
+class InferShapeBase {
+ public:
+ virtual ~InferShapeBase() = default;
+ virtual void operator()(InferShapeContext*) const = 0;
+};
+
struct OpInfo {
OpCreator creator_;
GradOpMakerFN grad_op_maker_;
OpProto* proto_{nullptr};
OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_;
+ InferShapeFN infer_shape_;
bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr;
@@ -87,13 +94,13 @@ class OpInfoMap {
}
}
- const std::unordered_map& map() const {
- return map_;
- }
+ const std::unordered_map& map() const { return map_; }
+
+ std::unordered_map* mutable_map() { return &map_; }
private:
OpInfoMap() = default;
- std::unordered_map map_;
+ std::unordered_map map_;
DISABLE_COPY_AND_ASSIGN(OpInfoMap);
};
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index a67625fa88..db154e4f76 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -33,24 +33,6 @@ ExecutionContext::GetEigenDevice() const {
}
#endif
-const Tensor* GetTensorFromVar(const Variable* var) {
- if (var->IsType()) {
- return &var->Get();
- }
- PADDLE_ENFORCE(var->IsType(),
- "The Input must be LoDTensor or Tensor.");
- return &var->Get();
-}
-
-Tensor* GetTensorFromVar(Variable* var) {
- if (var->IsType()) {
- return var->GetMutable();
- }
- PADDLE_ENFORCE(var->IsType(),
- "The Input must be LoDTensor or Tensor.");
- return var->GetMutable();
-}
-
std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
@@ -204,6 +186,30 @@ void OperatorBase::GenerateTemporaryNames() {
}
}
+static const Tensor* GetTensorFromVar(const Variable* var) {
+ const Tensor* t = nullptr;
+ if (var->IsType()) {
+ t = &(var->Get());
+ } else if (var->IsType()) {
+ t = &(var->Get().value());
+ } else {
+ PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+ }
+ return t;
+}
+
+static Tensor* GetMutableTensorFromVar(Variable* var) {
+ Tensor* t = nullptr;
+ if (var->IsType()) {
+ t = var->GetMutable();
+ } else if (var->IsType()) {
+ t = var->GetMutable()->mutable_value();
+ } else {
+ PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+ }
+ return t;
+}
+
template <>
const Tensor* ExecutionContext::Input(const std::string& name) const {
auto* var = InputVar(name);
@@ -227,7 +233,7 @@ const std::vector ExecutionContext::MultiInput(
template <>
Tensor* ExecutionContext::Output(const std::string& name) const {
auto var = OutputVar(name);
- return var == nullptr ? nullptr : var->GetMutable();
+ return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
}
template <>
@@ -240,7 +246,7 @@ std::vector ExecutionContext::MultiOutput(
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
return var == nullptr ? nullptr
- : var->GetMutable();
+ : GetMutableTensorFromVar(var);
});
return res;
}
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 0d0304ac9e..aa79f16df8 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -28,6 +28,7 @@ limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h"
+#include "paddle/framework/selected_rows.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
@@ -60,9 +61,6 @@ inline std::string GradVarName(const std::string& var_name) {
class OperatorBase;
class ExecutionContext;
-extern const Tensor* GetTensorFromVar(const Variable* var);
-extern Tensor* GetTensorFromVar(Variable* var);
-
/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -414,7 +412,9 @@ class CompileTimeInferShapeContext : public InferShapeContext {
private:
DDim GetDim(const std::string& name) const override {
- return framework::make_ddim(block_.FindVarRecursive(name)->Shape());
+ auto var = block_.FindVarRecursive(name);
+ PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
+ return framework::make_ddim(var->Shape());
}
void SetDim(const std::string& name, const DDim& dim) override {
@@ -511,28 +511,26 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
private:
- template
- Tensor* GetTensor(const std::string& name) const {
- Tensor* t = nullptr;
- auto* var = scope_.FindVar(name);
- if (!var->IsType() && !var->IsType()) {
- if (Allocate) {
- t = var->GetMutable();
- } else {
- PADDLE_THROW("Variable(%s) should be tensor", name);
- }
+ DDim GetDim(const std::string& name) const override {
+ Variable* var = scope_.FindVar(name);
+ if (var->IsType()) {
+ return var->Get().dims();
+ } else if (var->IsType()) {
+ return var->Get().GetCompleteDims();
} else {
- t = GetTensorFromVar(scope_.FindVar(name));
+ PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
}
- return t;
- }
-
- DDim GetDim(const std::string& name) const override {
- return GetTensor(name)->dims();
}
void SetDim(const std::string& name, const DDim& dim) override {
- GetTensor(name)->Resize(dim);
+ Variable* var = scope_.FindVar(name);
+ if (var->IsType()) {
+ var->GetMutable()->Resize(dim);
+ } else if (var->IsType()) {
+ var->GetMutable()->set_height(dim[0]);
+ } else {
+ PADDLE_THROW("Variable type must be LoDTensor/SelectedRows.");
+ }
}
const OperatorBase& op_;
@@ -638,7 +636,9 @@ class OperatorWithKernel : public OperatorBase {
});
}
- virtual void InferShape(InferShapeContext* ctx) const = 0;
+ virtual void InferShape(InferShapeContext* ctx) const {
+ OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
+ }
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
@@ -655,11 +655,14 @@ class OperatorWithKernel : public OperatorBase {
t = &var->Get();
} else if (var->IsType()) {
t = &var->Get();
+ } else if (var->IsType()) {
+ t = &(var->Get().value());
}
if (t != nullptr) {
int tmp = static_cast(ToDataType(t->type()));
+ VLOG(3) << "Input " << ipt_name << " with data_type " << tmp;
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
- "DataType of Paddle Op must be same.");
+ "DataType of Paddle Op %s must be same.", Type());
data_type = tmp;
}
}
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index c358f1a2b6..3c07621293 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -237,12 +237,12 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope;
- scope.Var("x0")->GetMutable();
- scope.Var("x1")->GetMutable();
- scope.Var("x2")->GetMutable();
- scope.Var("k0")->GetMutable();
- scope.Var("y0")->GetMutable();
- scope.Var("y1")->GetMutable();
+ scope.Var("x0")->GetMutable();
+ scope.Var("x1")->GetMutable();
+ scope.Var("x2")->GetMutable();
+ scope.Var("k0")->GetMutable();
+ scope.Var("y0")->GetMutable();
+ scope.Var("y1")->GetMutable();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
op->Run(scope, cpu_device_context);
diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc
index 8e99bba811..82f16a7c8b 100644
--- a/paddle/framework/program_desc.cc
+++ b/paddle/framework/program_desc.cc
@@ -19,9 +19,9 @@ namespace paddle {
namespace framework {
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
- auto *b = prog_.add_blocks();
+ auto *b = desc_.add_blocks();
b->set_parent_idx(parent.ID());
- b->set_idx(prog_.blocks_size() - 1);
+ b->set_idx(desc_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get();
}
@@ -30,23 +30,32 @@ ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Flush();
}
- return &prog_;
+ return &desc_;
}
ProgramDescBind::ProgramDescBind() {
- auto *block = prog_.mutable_blocks()->Add();
+ auto *block = desc_.mutable_blocks()->Add();
block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex);
blocks_.emplace_back(new BlockDescBind(this, block));
}
ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
- prog_ = o.prog_;
+ desc_ = o.desc_;
- for (int i = 0; i < prog_.blocks_size(); ++i) {
- auto *block = prog_.mutable_blocks(i);
+ for (int i = 0; i < desc_.blocks_size(); ++i) {
+ auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this));
}
}
+
+ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
+ PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
+ "Fail to parse program_desc from binary string.");
+ for (auto &block_desc : *desc_.mutable_blocks()) {
+ blocks_.emplace_back(new BlockDescBind(this, &block_desc));
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h
index dc4cd7cc73..b6e76515a5 100644
--- a/paddle/framework/program_desc.h
+++ b/paddle/framework/program_desc.h
@@ -31,6 +31,8 @@ class ProgramDescBind {
ProgramDescBind(const ProgramDescBind &o);
+ explicit ProgramDescBind(const std::string &binary_str);
+
BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
@@ -40,7 +42,7 @@ class ProgramDescBind {
ProgramDesc *Proto();
private:
- ProgramDesc prog_;
+ ProgramDesc desc_;
std::vector> blocks_;
};
diff --git a/paddle/framework/program_desc_test.cc b/paddle/framework/program_desc_test.cc
index c9709a2d3f..d28c2a0bff 100644
--- a/paddle/framework/program_desc_test.cc
+++ b/paddle/framework/program_desc_test.cc
@@ -59,7 +59,7 @@ TEST(ProgramDesc, copy_ctor) {
};
ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames());
- ASSERT_EQ(3, global_block_copy->LocalVarNames().size());
+ ASSERT_EQ(3UL, global_block_copy->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);
@@ -79,5 +79,67 @@ TEST(ProgramDesc, copy_ctor) {
// Not check block's protostr are same it because the order of vars could be
// different and it is correct.
}
+
+TEST(ProgramDescBind, serialize_and_deserialize) {
+ ProgramDescBind program_origin;
+ auto* global_block = program_origin.Block(0);
+ auto* x = global_block->Var("X");
+ x->SetType(VarDesc_VarType_LOD_TENSOR);
+ x->SetLoDLevel(0);
+ x->SetDataType(FP32);
+ x->SetShape({1000, 784});
+
+ auto* y = global_block->Var("Y");
+ y->SetType(VarDesc_VarType_LOD_TENSOR);
+ y->SetLoDLevel(0);
+ y->SetDataType(FP32);
+ y->SetShape({784, 100});
+
+ auto* op = global_block->AppendOp();
+ op->SetType("mul");
+ op->SetInput("X", {x->Name()});
+ op->SetInput("Y", {y->Name()});
+
+ auto* out = global_block->Var("Out");
+ out->SetType(VarDesc_VarType_LOD_TENSOR);
+ op->SetOutput("Y", {out->Name()});
+
+ std::string binary_str;
+ program_origin.Proto()->SerializeToString(&binary_str);
+
+ ProgramDescBind program_restored(binary_str);
+ auto* global_block_restored = program_restored.Block(0);
+ ASSERT_NE(global_block, global_block_restored);
+
+ auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
+ ASSERT_TRUE(global_block_restored->HasVar(name));
+ auto* restored = global_block_restored->Var(name);
+ ASSERT_NE(restored, var_before);
+ ASSERT_EQ(restored->Name(), var_before->Name());
+ ASSERT_EQ(restored->GetType(), var_before->GetType());
+ ASSERT_EQ(restored->Shape(), var_before->Shape());
+ ASSERT_EQ(restored->Proto()->SerializeAsString(),
+ var_before->Proto()->SerializeAsString());
+ };
+
+ ASSERT_EQ(global_block->LocalVarNames(),
+ global_block_restored->LocalVarNames());
+ ASSERT_EQ(3UL, global_block_restored->LocalVarNames().size());
+ assert_same_var("X", x);
+ assert_same_var("Y", y);
+ assert_same_var("Out", out);
+
+ for (size_t i = 0; i < global_block->OpSize(); ++i) {
+ auto op_origin = global_block->Op(i);
+ auto op_restored = global_block->Op(i);
+
+ ASSERT_EQ(op_origin->Type(), op_restored->Type());
+ ASSERT_EQ(op_origin->Inputs(), op_restored->Inputs());
+ ASSERT_EQ(op_origin->Outputs(), op_restored->Outputs());
+
+ ASSERT_EQ(op_restored->Proto()->SerializeAsString(),
+ op_origin->Proto()->SerializeAsString());
+ }
+}
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/saver.proto b/paddle/framework/saver.proto
deleted file mode 100644
index 90a191a6a7..0000000000
--- a/paddle/framework/saver.proto
+++ /dev/null
@@ -1,39 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-syntax = "proto2";
-option optimize_for = LITE_RUNTIME;
-package paddle.framework;
-
-import "framework.proto";
-
-/**
- * This file contains necessary information for model, checkpoint.
- * etc.
- */
-
-message LoDInfo { repeated int64 level = 1; }
-
-/**
- * Save the LoDTensorDesc information through LoDTensorProto, its data memory
- * is copyed to c buffer immediately. See model_format.md for details.
- */
-
-message LoDTensorProto {
- optional DataType data_type = 1;
- repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
- repeated LoDInfo levels = 3;
- optional int32 lod_level = 4 [ default = 0 ];
- optional int32 version = 5;
-}
diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h
index cd90781371..0332b91323 100644
--- a/paddle/framework/selected_rows.h
+++ b/paddle/framework/selected_rows.h
@@ -23,7 +23,10 @@ class SelectedRows {
value_.reset(new Tensor());
}
- SelectedRows() { value_.reset(new Tensor()); }
+ SelectedRows() {
+ height_ = 0;
+ value_.reset(new Tensor());
+ }
platform::Place place() const { return value_->place(); }
@@ -37,6 +40,8 @@ class SelectedRows {
const Vector& rows() const { return rows_; }
+ Vector* mutable_rows() { return &rows_; }
+
void set_rows(const Vector& rows) { rows_ = rows; }
DDim GetCompleteDims() const {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index e31472327d..9d2dc6a32b 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -132,6 +132,8 @@ class Tensor {
std::type_index type() const { return holder_->type(); }
+ size_t memory_size() const;
+
private:
inline void check_memory_size() const;
diff --git a/paddle/framework/tensor_array.cc b/paddle/framework/tensor_array.cc
index 6f0b84dd1a..0947e33548 100644
--- a/paddle/framework/tensor_array.cc
+++ b/paddle/framework/tensor_array.cc
@@ -254,13 +254,12 @@ LoDTensor TensorArray::LodPackTwo(const LoDTensor& pre, const LoDTensor& cur,
void TensorArray::LodUnpack(const LoDTensor& source, size_t level) {
PADDLE_ENFORCE_EQ(level, source.NumLevels() - 1,
"only the lowest LoD level supports unpack.");
- int non_empty_instances = -1;
+ const size_t non_empty_instances = source.dims()[0];
size_t index = 0;
Vector lowest_lod_level;
lowest_lod_level.push_back(index);
- for (size_t step = 0; non_empty_instances > 0 || non_empty_instances == -1;
- step++) {
+ for (size_t step = 0; step < non_empty_instances; step++) {
size_t num_instances = 0;
for (size_t id = 0; id < source.NumElements(level); id++) {
auto instance = source;
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index f6e801bbb4..29ac683f48 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -62,12 +62,16 @@ inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(
- holder_->size(), numel() * SizeOfType(type()) + offset_,
+ holder_->size(), memory_size() + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored.");
}
+inline size_t Tensor::memory_size() const {
+ return holder_ == nullptr ? 0UL : numel() * SizeOfType(type());
+}
+
template
inline const T* Tensor::data() const {
check_memory_size();
diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h
index 00da728939..c38c4a8ae9 100644
--- a/paddle/framework/type_defs.h
+++ b/paddle/framework/type_defs.h
@@ -28,6 +28,8 @@ class OperatorBase;
class OpDescBind;
class BlockDescBind;
class BlockDesc;
+class InferShapeContext;
+
using VariableNameMap = std::map>;
// The order should be as same as framework.proto
@@ -49,5 +51,7 @@ using GradOpMakerFN = std::function>(
using InferVarTypeFN = std::function;
+using InferShapeFN = std::function;
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h
index 929de1f836..70daa20e8d 100644
--- a/paddle/framework/var_desc.h
+++ b/paddle/framework/var_desc.h
@@ -59,6 +59,8 @@ class VarDescBind {
desc_.set_type(VarDesc::LOD_TENSOR);
}
+ explicit VarDescBind(const VarDesc &desc) : desc_(desc) {}
+
VarDesc *Proto() { return &desc_; }
std::string Name() const { return desc_.name(); }
diff --git a/paddle/framework/variable.h b/paddle/framework/variable.h
index a80f0e66b5..cde5ec2413 100644
--- a/paddle/framework/variable.h
+++ b/paddle/framework/variable.h
@@ -46,6 +46,8 @@ class Variable {
std::type_index(typeid(T)) == std::type_index(holder_->Type());
}
+ void Clear() { holder_.reset(); }
+
private:
struct Placeholder {
virtual ~Placeholder() {}
diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
new file mode 100644
index 0000000000..9b0ae20f08
--- /dev/null
+++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp
@@ -0,0 +1,309 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "MKLDNNBatchNormLayer.h"
+
+using namespace mkldnn; // NOLINT
+typedef memory::format format;
+
+namespace paddle {
+
+REGISTER_LAYER(mkldnn_batch_norm, MKLDNNBatchNormLayer);
+
+const real MKLDNNBatchNormLayer::EPS = 1E-5;
+
+bool MKLDNNBatchNormLayer::init(const LayerMap& layerMap,
+ const ParameterMap& parameterMap) {
+ if (!MKLDNNLayer::init(layerMap, parameterMap)) {
+ return false;
+ }
+
+ // first one is input layer
+ // the other two are created in config_parser.py saving moving mean and var
+ CHECK_EQ(inputLayers_.size(), 3U);
+ CHECK_EQ(inputLayers_.size(), parameters_.size());
+ CHECK_EQ(inputLayers_.size(), size_t(config_.inputs_size()));
+
+ const ImageConfig& conf = config_.inputs(0).image_conf();
+ ic_ = conf.channels();
+ ih_ = inputLayers_[0]->getOutput().getFrameHeight();
+ iw_ = inputLayers_[0]->getOutput().getFrameWidth();
+ if (iw_ == 0 && ih_ == 0) {
+ iw_ = conf.img_size();
+ ih_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
+ }
+ oc_ = ic_;
+ oh_ = ih_;
+ ow_ = iw_;
+ if (config_.has_use_global_stats()) {
+ useGlobalStats_ = config_.use_global_stats();
+ }
+ movingAvgFraction_ = config_.moving_average_fraction();
+ VLOG(MKLDNN_BASE) << "--- " << (useGlobalStats_ ? "use" : "do not use")
+ << " --- global stats";
+ VLOG(MKLDNN_BASE) << "Moving average fraction: " << movingAvgFraction_;
+
+ initWeight();
+ movingMean_.reset(new Weight(oc_, 1, parameters_[1], 0));
+ movingVar_.reset(new Weight(oc_, 1, parameters_[2], 0));
+ return true;
+}
+
+void MKLDNNBatchNormLayer::initWeight() {
+ weight_.reset(new Weight(1, oc_, parameters_[0]));
+ if (biasParameter_.get() != NULL) {
+ biases_ = std::unique_ptr(new Weight(1, oc_, biasParameter_));
+ }
+ CHECK_EQ(weight_ != nullptr, biases_ != nullptr)
+ << "only support have both weight and bias, or neither";
+ if (weight_ && weight_->getW()) {
+ CHECK(biases_ && biases_->getW());
+ valueScaleShift_ = Matrix::create(2, oc_, false, false);
+ valueScaleShift_->zeroMem();
+ VectorPtr scale(new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), 0));
+ VectorPtr shift(
+ new CpuVector(oc_, valueScaleShift_->getMemoryHandle(), oc_));
+ const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_VALUE);
+ const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_VALUE);
+ scale->copyFrom(*wgt);
+ shift->copyFrom(*bias);
+ wgt->setData(valueScaleShift_->getData());
+ bias->setData(valueScaleShift_->getData() + oc_);
+ }
+ if (weight_ && weight_->getWGrad()) {
+ CHECK(biases_ && biases_->getWGrad());
+ gradScaleShift_ = Matrix::create(2, oc_, false, false);
+ gradScaleShift_->zeroMem();
+ const VectorPtr& wgt = parameters_[0]->getBuf(PARAMETER_GRADIENT);
+ const VectorPtr& bias = biasParameter_->getBuf(PARAMETER_GRADIENT);
+ wgt->setData(gradScaleShift_->getData());
+ bias->setData(gradScaleShift_->getData() + oc_);
+ }
+}
+
+void MKLDNNBatchNormLayer::convertWeightsFromPaddle() {
+ if (hasInitedWgt_) {
+ return;
+ }
+ // prepare mean and var if necessary
+ if (useGlobalStats_) {
+ CHECK(mean_);
+ CHECK(var_);
+ mean_->copyFrom(*(movingMean_->getW()));
+ var_->copyFrom(*(movingVar_->getW()));
+ }
+ hasInitedWgt_ = true;
+}
+
+void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
+ // calculating and saving moving mean and variance
+ CHECK_EQ(useGlobalStats_, false);
+ movingMean_->getW()->add(
+ *mean_, movingAvgFraction_, 1.0 - movingAvgFraction_);
+ // here var is v^2
+ movingVar_->getW()->add(*var_, movingAvgFraction_, 1.0 - movingAvgFraction_);
+}
+
+void MKLDNNBatchNormLayer::reshape(
+ int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+ reshapeInput(bs, ih, iw);
+ oh = ih;
+ ow = ow;
+ // ic_ and oc can not be changed
+ CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
+ << "Input channel can not be changed";
+ reshapeOutput(oh, ow);
+ resizeOutput(bs, oc * oh * ow);
+ printSizeInfo();
+}
+
+void MKLDNNBatchNormLayer::resetFwd(std::vector& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) {
+ // In training phase, it will always calculate mean and var,
+ // so useGlobalStats must be false.
+ // In scoring phase, it depends on useGlobalStats choice.
+ if (passType_ != PASS_TEST && useGlobalStats_ == true) {
+ LOG(WARNING) << "use_global_stats is invalid setting in training phase";
+ useGlobalStats_ = false;
+ }
+
+ resetFwdBuffers(in, wgt, out);
+
+ resetFwdPD(fwdPD_, in, wgt, out);
+
+ resetFwdPipeline(pipeline, fwdPD_, in, wgt, out);
+}
+
+void MKLDNNBatchNormLayer::resetBwd(std::vector& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) {
+ std::shared_ptr pd;
+
+ resetBwdBuffers(in, wgt, out);
+
+ resetBwdPD(pd, in, wgt, out);
+
+ resetBwdPipeline(pipeline, pd, in, wgt, out);
+}
+
+void MKLDNNBatchNormLayer::forward(PassType passType) {
+ MKLDNNLayer::forward(passType);
+
+ // calculate and save moving mean and variance
+ if (passType_ != PASS_TEST) {
+ calMovingMeanAndVar();
+ }
+}
+
+void MKLDNNBatchNormLayer::updateWeights(const UpdateCallback& callback) {
+ weight_->getParameterPtr()->incUpdate(callback);
+ if (biases_ && biases_->getWGrad()) {
+ biases_->getParameterPtr()->incUpdate(callback);
+ }
+}
+
+void MKLDNNBatchNormLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ resetInValue(in);
+
+ memory::dims outDims = memory::dims{bs_, oc_, oh_, ow_};
+ CHECK(in);
+ auto outPD =
+ MKLDNNMatrix::createPrimitiveDesc(outDims, in->getFormat(), engine_);
+ resetOutValue(out, outPD);
+
+ if (valueScaleShift_) {
+ auto pd = MKLDNNMatrix::createPrimitiveDesc({2, oc_}, format::nc, engine_);
+ resetWithMatrix(wgt, valueScaleShift_, pd);
+ }
+ if (passType_ != PASS_TEST || useGlobalStats_) {
+ auto pd = MKLDNNMatrix::createPrimitiveDesc({oc_}, format::x, engine_);
+ mean_ = MKLDNNMatrix::create(pd);
+ var_ = MKLDNNMatrix::create(pd);
+ }
+}
+
+void MKLDNNBatchNormLayer::resetFwdPD(
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr in,
+ MKLDNNMatrixPtr wgt,
+ MKLDNNMatrixPtr out) {
+ flags_ = 0u;
+ prop_kind pk = passType_ == PASS_TEST ? prop_kind::forward_scoring
+ : prop_kind::forward_training;
+ if (useGlobalStats_) {
+ flags_ = (flags_ | batch_normalization_flag::use_global_stats);
+ }
+ if (wgt) {
+ flags_ = (flags_ | batch_normalization_flag::use_scale_shift);
+ }
+ auto fwdDesc = bn_fwd::desc(pk, in->getMemoryDesc(), EPS, flags_);
+ pd.reset(new bn_fwd::primitive_desc(fwdDesc, engine_));
+ CHECK_PRIMITIVE_DESC_EQ(out, pd->dst_primitive_desc());
+ if (wgt) {
+ CHECK_PRIMITIVE_DESC_EQ(wgt, pd->weights_primitive_desc());
+ }
+ if (passType_ != PASS_TEST || useGlobalStats_) {
+ CHECK_PRIMITIVE_DESC_EQ(mean_, pd->mean_primitive_desc());
+ CHECK_PRIMITIVE_DESC_EQ(var_, pd->variance_primitive_desc());
+ }
+}
+
+void MKLDNNBatchNormLayer::resetFwdPipeline(
+ std::vector& pipeline,
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ if (passType_ == PASS_TEST) {
+ if (useGlobalStats_) {
+ fwd_.reset(wgt != nullptr ? new bn_fwd(*pd,
+ *in,
+ (const primitive::at)(*mean_),
+ (const primitive::at)(*var_),
+ *wgt,
+ *out)
+ : new bn_fwd(*pd,
+ *in,
+ (const primitive::at)(*mean_),
+ (const primitive::at)(*var_),
+ *out));
+ } else {
+ fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out)
+ : new bn_fwd(*pd, *in, *out));
+ }
+ } else {
+ CHECK_EQ(useGlobalStats_, false)
+ << "useGlobalStats should be false in training";
+ fwd_.reset(wgt != nullptr ? new bn_fwd(*pd, *in, *wgt, *out, *mean_, *var_)
+ : new bn_fwd(*pd, *in, *out, *mean_, *var_));
+ }
+ pipeline.push_back(*fwd_);
+}
+
+void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ CHECK(inVal_ && outVal_);
+ resetOutGrad(out, outVal_->getPrimitiveDesc());
+ resetInGrad(in, inVal_->getPrimitiveDesc());
+ if (gradScaleShift_) {
+ CHECK(wgtVal_);
+ resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc());
+ }
+}
+
+void MKLDNNBatchNormLayer::resetBwdPD(
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ pd = nullptr;
+ if (in == nullptr) {
+ return;
+ }
+ CHECK_PRIMITIVE_DESC_EQ(out, in->getPrimitiveDesc());
+ auto md = in->getMemoryDesc();
+ auto bwdDesc = bn_bwd::desc(prop_kind::backward, md, md, EPS, flags_);
+ pd.reset(new bn_bwd::primitive_desc(bwdDesc, engine_, *fwdPD_));
+ CHECK(pd->weights_primitive_desc() == fwdPD_->weights_primitive_desc());
+ CHECK_PRIMITIVE_DESC_EQ(wgt, pd->diff_weights_primitive_desc());
+ CHECK_PRIMITIVE_DESC_EQ(mean_, pd->mean_primitive_desc());
+ CHECK_PRIMITIVE_DESC_EQ(var_, pd->variance_primitive_desc());
+}
+
+void MKLDNNBatchNormLayer::resetBwdPipeline(
+ std::vector& pipeline,
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out) {
+ if (pd == nullptr) {
+ return;
+ }
+ CHECK(inVal_);
+ bwdData_.reset(
+ wgt && wgtVal_
+ ? new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *wgtVal_, *in, *wgt)
+ : new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *in));
+ pipeline.push_back(*bwdData_);
+}
+
+} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.h b/paddle/gserver/layers/MKLDNNBatchNormLayer.h
new file mode 100644
index 0000000000..456c0424ec
--- /dev/null
+++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.h
@@ -0,0 +1,138 @@
+/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "MKLDNNLayer.h"
+#include "mkldnn.hpp"
+
+namespace paddle {
+typedef mkldnn::batch_normalization_forward bn_fwd;
+typedef mkldnn::batch_normalization_backward bn_bwd;
+
+/**
+ * @brief A subclass of MKLDNNLayer BatchNorm layer.
+ *
+ * The config file api is mkldnn_batch_norm
+ */
+class MKLDNNBatchNormLayer : public MKLDNNLayer {
+protected:
+ // save forward primitive_desc, which can be used backward
+ std::shared_ptr fwdPD_;
+
+ // Epsilon value used in the batch normalization formula.
+ static const real EPS;
+ // weight and bias in paddle
+ std::unique_ptr weight_;
+ std::unique_ptr biases_;
+ // mkldnn use a large buffer store both scale and shift
+ // which are weight and bias in paddle corresponding.
+ MatrixPtr valueScaleShift_;
+ MatrixPtr gradScaleShift_;
+ // Moving average of mean.
+ std::unique_ptr movingMean_;
+ // Moving average of variance.
+ std::unique_ptr movingVar_;
+
+ // if useGlobalStats_ is true, will use the loaded mean and variance.
+ // otherwise, calculate mean and variance in every mini-batch.
+ bool useGlobalStats_;
+ // used in MKLDNN primitive desc
+ unsigned flags_;
+ // use to compute moving mean and variance.
+ real movingAvgFraction_;
+ // whether the weight has been init
+ bool hasInitedWgt_;
+
+ // local mean and variance
+ // when useGlobalStats_ they are loaded from moving mean and variance
+ // when do not useGlobalStats_ they are calculated from this mini-batch
+ MKLDNNMatrixPtr mean_;
+ MKLDNNMatrixPtr var_;
+
+public:
+ explicit MKLDNNBatchNormLayer(const LayerConfig& config)
+ : MKLDNNLayer(config), useGlobalStats_(true), hasInitedWgt_(false) {}
+
+ ~MKLDNNBatchNormLayer() {}
+
+ bool init(const LayerMap& layerMap,
+ const ParameterMap& parameterMap) override;
+
+ void forward(PassType passType) override;
+
+ void reshape(
+ int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
+
+ void resetFwd(std::vector& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) override;
+
+ void resetBwd(std::vector& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) override;
+
+ void updateWeights(const UpdateCallback& callback) override;
+
+ void convertWeightsFromPaddle() override;
+
+protected:
+ void initWeight();
+ /**
+ * cal moving mean and variance.
+ * moving = moving * AvgFraction + local * (1 - AvgFraction)
+ */
+ void calMovingMeanAndVar();
+ /**
+ * Forward functions: reset buffers(input, weight, output),
+ * reset primitive descriptor,
+ * reset pipeline.
+ */
+ void resetFwdBuffers(MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+ void resetFwdPD(std::shared_ptr& pd,
+ MKLDNNMatrixPtr in,
+ MKLDNNMatrixPtr wgt,
+ MKLDNNMatrixPtr out);
+ void resetFwdPipeline(std::vector& pipeline,
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+
+ /**
+ * Backward functions: reset buffers(input, weight, output),
+ * reset primitive descriptor,
+ * reset pipeline.
+ */
+ void resetBwdBuffers(MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+ void resetBwdPD(std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+ void resetBwdPipeline(std::vector& pipeline,
+ std::shared_ptr& pd,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& out);
+};
+
+} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp
index 83f4e4e615..b8120eda1e 100644
--- a/paddle/gserver/layers/MKLDNNConvLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp
@@ -262,12 +262,15 @@ void MKLDNNConvLayer::resetBwdWgtPD(
padR,
padKind);
pd.reset(new conv_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_));
- CHECK(pd->src_primitive_desc() == inVal_->getPrimitiveDesc())
- << "primitive desc of in value should equal";
- CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc())
- << "primitive desc of out grad should equal the out value";
- CHECK(pd->diff_weights_primitive_desc() == wgtVal_->getPrimitiveDesc())
- << "primitive desc of weight grad should equal the weight value";
+ CHECK_PRIMITIVE_DESC_EQ(inVal_, pd->src_primitive_desc());
+ CHECK_PRIMITIVE_DESC_EQ(
+ outVal_,
+ pd->diff_dst_primitive_desc(),
+ "primitive desc of out value and grad should be equal");
+ CHECK_PRIMITIVE_DESC_EQ(
+ wgtVal_,
+ pd->diff_weights_primitive_desc(),
+ "primitive desc of weight value and grad should be equal");
}
void MKLDNNConvLayer::resetBwdDataPD(
@@ -292,10 +295,14 @@ void MKLDNNConvLayer::resetBwdDataPD(
padR,
padding_kind::zero);
pd.reset(new conv_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_));
- CHECK(pd->diff_src_primitive_desc() == inVal_->getPrimitiveDesc())
- << "primitive desc of in grad should equal the in value";
- CHECK(pd->diff_dst_primitive_desc() == outVal_->getPrimitiveDesc())
- << "primitive desc of out grad should equal";
+ CHECK_PRIMITIVE_DESC_EQ(
+ inVal_,
+ pd->diff_src_primitive_desc(),
+ "primitive desc of in value and grad should be equal");
+ CHECK_PRIMITIVE_DESC_EQ(
+ outVal_,
+ pd->diff_dst_primitive_desc(),
+ "primitive desc of out value and grad should be equal");
}
void MKLDNNConvLayer::resetBwdBuffers(
@@ -310,17 +317,20 @@ void MKLDNNConvLayer::resetBwdBuffers(
resetWithMatrix(
wgt, weight_->getWGrad(), wgtPD->diff_weights_primitive_desc());
- CHECK(wgtVal_ != nullptr &&
- wgt->getPrimitiveDesc() == wgtVal_->getPrimitiveDesc())
- << "primitive desc of weight grad and value should be equal";
+ CHECK_PRIMITIVE_DESC_EQ(
+ wgtVal_,
+ wgt->getPrimitiveDesc(),
+ "primitive desc of weight grad and value should be equal");
bias = nullptr;
if (biases_ && biases_->getWGrad()) {
resetWithMatrix(
bias, biases_->getWGrad(), wgtPD->diff_bias_primitive_desc());
- CHECK(bias && biasVal_ &&
- bias->getPrimitiveDesc() == biasVal_->getPrimitiveDesc())
- << "primitive desc of bias grad should equal the bias value";
+ CHECK(bias);
+ CHECK_PRIMITIVE_DESC_EQ(
+ biasVal_,
+ bias->getPrimitiveDesc(),
+ "primitive desc of bias grad and value should be equal");
}
if (dataPD == nullptr) {
diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp
index 6bb19976b5..663a105098 100644
--- a/paddle/gserver/layers/MKLDNNLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNLayer.cpp
@@ -235,8 +235,7 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in,
in = MKLDNNMatrix::create(intPD, inMat);
Argument& arg = input->getOutput(this->getName());
arg.grad = std::dynamic_pointer_cast(in);
- CHECK(inVal_);
- CHECK(inVal_->getPrimitiveDesc() == intPD) << "the primitive desc must equal";
+ CHECK_PRIMITIVE_DESC_EQ(inVal_, intPD);
if (inputIsOnlyMKLDNN()) {
return;
}
@@ -250,8 +249,7 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in,
CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat()))
<< "should have external input value and the format must be nchw(nc)";
extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat);
- CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD)
- << "should have internal input value and primitive desc must equal";
+ CHECK_PRIMITIVE_DESC_EQ(inVal_, intPD);
in = MKLDNNMatrix::create(intPD);
cvtInGrad_ = MKLDNNMatrix::createReorder(in, extInGrad_);
CHECK(cvtInGrad_);
@@ -277,8 +275,7 @@ void MKLDNNLayer::resetOutGrad(MKLDNNMatrixPtr& out,
CHECK(extOutVal_ != nullptr && isPaddleFormat(extOutVal_->getFormat()))
<< "should have external output value and the format must be nchw(nc)";
extOutGrad_ = MKLDNNMatrix::create(extOutVal_->getPrimitiveDesc(), outMat);
- CHECK(outVal_ != nullptr && outVal_->getPrimitiveDesc() == intPD)
- << "should have internal output value and primitive desc must equal";
+ CHECK_PRIMITIVE_DESC_EQ(outVal_, intPD);
out = MKLDNNMatrix::create(intPD);
cvtOutGrad_ = MKLDNNMatrix::createReorder(extOutGrad_, out);
CHECK(cvtOutGrad_);
diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp
index 0a19fe2333..73b7e8857f 100644
--- a/paddle/gserver/tests/MKLDNNTester.cpp
+++ b/paddle/gserver/tests/MKLDNNTester.cpp
@@ -91,10 +91,16 @@ void MKLDNNTester::setInputImgSize() {
// init randome parameters of ref, and copy to mkldnn
void MKLDNNTester::randomWgtDatas() {
EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size());
+ const bool isBN = refLayer_->getType() == "batch_norm";
for (size_t i = 0; i < parameters_[REF].size(); ++i) {
const VectorPtr& dnnValue = parameters_[DNN][i]->getBuf(PARAMETER_VALUE);
const VectorPtr& refValue = parameters_[REF][i]->getBuf(PARAMETER_VALUE);
parameters_[REF][i]->randomize();
+ if (isBN && i == 2) {
+ // this param is moving average in batch norm, which must larger than 0
+ real offset = fabs(refValue->getMin()) + 1.0;
+ refValue->add(offset);
+ }
dnnValue->copyFrom(*refValue);
VLOG(MKLDNN_TESTS) << "Random weight " << parameters_[DNN][i]->getName();
@@ -132,8 +138,7 @@ void MKLDNNTester::checkForward() {
void MKLDNNTester::checkBackwardData() {
VLOG(MKLDNN_TESTS) << "Check Backward Data";
- // TODO(TJ): uncomment me when batch norm ready
- // const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm";
+ const bool isBN = refLayer_->getType() == "batch_norm";
for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) {
const MatrixPtr& dnnDiff = dataLayers_[DNN][i]->getOutputGrad();
const MatrixPtr& refDiff = dataLayers_[REF][i]->getOutputGrad();
@@ -144,11 +149,11 @@ void MKLDNNTester::checkBackwardData() {
double delta = compareMatrix(dnnDiff, refDiff);
EXPECT_LE(fabs(delta), eps_);
- // TODO(TJ): uncomment me when batch norm ready
- // if (isBN) {
- // // the other two inputs in batch norm are for moving mean and var
- // break;
- // }
+ if (isBN) {
+ // the other two inputs in batch norm are for moving mean and var
+ // do not have grad to compare
+ break;
+ }
}
}
@@ -308,10 +313,14 @@ double MKLDNNTester::compareVector(const VectorPtr& v1, const VectorPtr& v2) {
void MKLDNNTester::runOnce() {
// test forward
randomBotDatas();
- dnnLayer_->forward(PASS_TRAIN);
- refLayer_->forward(PASS_TRAIN);
+ dnnLayer_->forward(passType_);
+ refLayer_->forward(passType_);
checkForward();
+ if (passType_ == PASS_TEST) {
+ return;
+ }
+
// test backward
// simple updater
UpdateCallback updateCallback = [](Parameter* para) {
@@ -343,6 +352,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
size_t batchSize,
size_t inputImgH,
size_t inputImgW,
+ PassType passType,
bool printDetails,
size_t iter,
float epsilon) {
@@ -361,6 +371,7 @@ void MKLDNNTester::run(const TestConfig& dnn,
ih_ = inputImgH;
iw_ = inputImgW;
+ passType_ = passType;
log_ = printDetails;
iter_ = iter;
eps_ = epsilon;
diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h
index c385d1c727..19d8848f74 100644
--- a/paddle/gserver/tests/MKLDNNTester.h
+++ b/paddle/gserver/tests/MKLDNNTester.h
@@ -62,12 +62,15 @@ protected:
float eps_;
/// input image size, default 1
size_t ih_, iw_;
+ /// passType, PASS_TRAIN, PASS_TEST or PASS_GC (Gradient Check pass)
+ PassType passType_;
public:
explicit MKLDNNTester(size_t iter = 3, float epsilon = 1e-4) {
iter_ = iter;
eps_ = epsilon;
log_ = false;
+ passType_ = PASS_TRAIN;
}
~MKLDNNTester() {}
@@ -78,6 +81,7 @@ public:
size_t batchSize,
size_t inputImgH = 1,
size_t inputImgW = 1,
+ PassType passType = PASS_TRAIN,
bool printDetails = false,
size_t iter = 3,
float epsilon = 1e-4);
diff --git a/paddle/gserver/tests/test_MKLDNN.cpp b/paddle/gserver/tests/test_MKLDNN.cpp
index 6cb4ca5e08..85d4f437c2 100644
--- a/paddle/gserver/tests/test_MKLDNN.cpp
+++ b/paddle/gserver/tests/test_MKLDNN.cpp
@@ -212,6 +212,66 @@ TEST(MKLDNNLayer, PoolLayer) {
testPoolLayer({2, 8, 56, 56, 29, 29, 3, 3, 1, 1, 2, 2});
}
+struct testBatchNormDesc {
+ int bs;
+ int ic;
+ int ih, iw;
+};
+
+static void getMKLDNNBatchNormConfig(TestConfig& cfg,
+ const testBatchNormDesc& pm) {
+ cfg.layerConfig.set_size(pm.ic * pm.ih * pm.iw);
+ cfg.layerConfig.set_type("mkldnn_batch_norm");
+ cfg.biasSize = pm.ic;
+ cfg.inputDefs.push_back(
+ {INPUT_DATA,
+ "layer_0",
+ /* size of input layer= */ size_t(pm.ic * pm.ih * pm.iw),
+ /* size of weight= */ size_t(pm.ic)});
+ cfg.inputDefs.push_back(
+ {INPUT_DATA, "layer_1_moving_mean", 1, size_t(pm.ic)});
+ cfg.inputDefs.back().isStatic = true;
+ cfg.inputDefs.push_back({INPUT_DATA, "layer_2_moving_var", 1, size_t(pm.ic)});
+ cfg.inputDefs.back().isStatic = true;
+ LayerInputConfig* input = cfg.layerConfig.add_inputs();
+ // TODO(TJ): uncomment me when refine and support comparing all zeroes vector
+ // cfg.layerConfig.set_active_type("relu");
+ cfg.layerConfig.add_inputs();
+ cfg.layerConfig.add_inputs();
+ ImageConfig* img_conf = input->mutable_image_conf();
+ img_conf->set_channels(pm.ic);
+ img_conf->set_img_size_y(pm.ih);
+ img_conf->set_img_size(pm.iw);
+}
+
+void testBatchNormLayer(const testBatchNormDesc& pm) {
+ TestConfig dnnConfig;
+ getMKLDNNBatchNormConfig(dnnConfig, pm);
+ TestConfig refConfig = dnnConfig;
+ refConfig.layerConfig.set_type("batch_norm");
+ // for PASS_TRAIN, use_global_stats always should be false, and batchsize != 1
+ VLOG(MKLDNN_TESTS) << "check train phase";
+ dnnConfig.layerConfig.set_use_global_stats(false);
+ refConfig.layerConfig.set_use_global_stats(false);
+ MKLDNNTester tester;
+ tester.run(dnnConfig, refConfig, pm.bs, pm.ih, pm.iw, PASS_TRAIN);
+ // for PASS_TEST, check use_global_stats true and false, and batchsize 1
+ VLOG(MKLDNN_TESTS) << "check test phase";
+ for (auto useGS : {false, true}) {
+ dnnConfig.layerConfig.set_use_global_stats(useGS);
+ refConfig.layerConfig.set_use_global_stats(useGS);
+ MKLDNNTester tester;
+ for (auto bs : {pm.bs, 1}) {
+ tester.run(dnnConfig, refConfig, bs, pm.ih, pm.iw, PASS_TEST);
+ }
+ }
+}
+
+TEST(MKLDNNLayer, BatchNormLayer) {
+ testBatchNormLayer({4, 10, 6, 6});
+ testBatchNormLayer({16, 32, 16, 16});
+}
+
struct testActDesc {
int bs, ic, ih, iw;
};
diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h
index fe755d096d..5f5b819017 100644
--- a/paddle/math/MKLDNNMatrix.h
+++ b/paddle/math/MKLDNNMatrix.h
@@ -24,6 +24,12 @@ namespace paddle {
class MKLDNNMatrix;
typedef std::shared_ptr MKLDNNMatrixPtr;
+#define CHECK_PRIMITIVE_DESC_EQ(MAT, PD, ...) \
+ CHECK(MAT) << " can not be empty."; \
+ CHECK(MAT->getPrimitiveDesc() == PD) \
+ << #MAT "->getPrimitiveDesc() and " #PD " should be equal.\n " \
+ << "" __VA_ARGS__;
+
/**
* @brief MKLDNN Matrix.
*
@@ -91,6 +97,11 @@ public:
const MKLDNNMatrixPtr& dst,
bool checkData = true);
+ void copyFrom(const Matrix& src) {
+ // TODO(TJ): reorder data if this format is not nchw or x
+ m_->copyFrom(src);
+ }
+
public:
/**
* Reorder this MKLDNNMatrix from other format.
diff --git a/paddle/math/RowBuffer.h b/paddle/math/RowBuffer.h
index 9ef5b89680..e457d71f1b 100644
--- a/paddle/math/RowBuffer.h
+++ b/paddle/math/RowBuffer.h
@@ -60,7 +60,7 @@ public:
*/
inline real* get(int row) const {
if (preallocatedBuf_) {
- CHECK_LE((row + 1) * width_ * sizeof(real), preallocatedBuf_->getSize());
+ CHECK_LE((row)*width_ * sizeof(real), preallocatedBuf_->getSize());
return reinterpret_cast(preallocatedBuf_->getBuf()) + row * width_;
} else {
CHECK_LE((row + 1) * width_, rowStore_.size());
diff --git a/paddle/memory/memcpy.h b/paddle/memory/memcpy.h
index 9b36182c2b..29c20e1860 100644
--- a/paddle/memory/memcpy.h
+++ b/paddle/memory/memcpy.h
@@ -54,6 +54,5 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);
#endif
-
} // namespace memory
} // namespace paddle
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index d2d70d8be7..c722617101 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -69,6 +69,13 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP(max_pool2d_with_index);\n")
endif()
+ # pool_cudnn_op contains several operators
+ if ("${TARGET}" STREQUAL "pool_cudnn_op")
+ set(pybind_flag 1)
+ # It's enough to just adding one operator to pybind
+ file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
+ endif()
+
# save_restore_op contains several operators
if ("${TARGET}" STREQUAL "save_restore_op")
set(pybind_flag 1)
@@ -82,7 +89,7 @@ function(op_library TARGET)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
-
+
# reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1)
@@ -123,6 +130,7 @@ set(DEPS_OPS
sum_op
pool_op
pool_with_index_op
+ sequence_conv_op
lstm_op)
@@ -131,9 +139,10 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
-op_library(sum_op DEPS net_op)
+op_library(sum_op DEPS net_op selected_rows_functor)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
+op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
@@ -148,3 +157,4 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
+cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index ee4f9b0ef2..90f1535fcd 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -446,12 +446,16 @@ REGISTER_OP(thresholded_relu, ops::ActivationOp,
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
hard_sigmoid_grad, ops::ActivationOpGrad);
-#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
- REGISTER_OP_CPU_KERNEL( \
- act_type, \
- ops::ActivationKernel>); \
- REGISTER_OP_CPU_KERNEL(act_type##_grad, \
- ops::ActivationGradKernel>);
+#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
+ REGISTER_OP_CPU_KERNEL( \
+ act_type, \
+ ops::ActivationKernel>, \
+ ops::ActivationKernel>); \
+ REGISTER_OP_CPU_KERNEL( \
+ act_type##_grad, ops::ActivationGradKernel>, \
+ ops::ActivationGradKernel>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu
index 7b7644519d..97737857ab 100644
--- a/paddle/operators/activation_op.cu
+++ b/paddle/operators/activation_op.cu
@@ -17,12 +17,16 @@
namespace ops = paddle::operators;
-#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
- REGISTER_OP_GPU_KERNEL( \
- act_type, \
- ops::ActivationKernel>); \
- REGISTER_OP_GPU_KERNEL(act_type##_grad, \
- ops::ActivationGradKernel>);
+#define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
+ REGISTER_OP_GPU_KERNEL( \
+ act_type, \
+ ops::ActivationKernel>, \
+ ops::ActivationKernel>); \
+ REGISTER_OP_GPU_KERNEL( \
+ act_type##_grad, ops::ActivationGradKernel>, \
+ ops::ActivationGradKernel>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
index 4f4eb44fed..e4c6b2e09c 100644
--- a/paddle/operators/activation_op.h
+++ b/paddle/operators/activation_op.h
@@ -210,8 +210,8 @@ struct HardShrinkFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y) const {
- auto temp1 = (x < (threshold * -1)).template cast().eval();
- auto temp2 = (x > threshold).template cast().eval();
+ auto temp1 = (x < static_cast(threshold * -1)).template cast().eval();
+ auto temp2 = (x > static_cast(threshold)).template cast().eval();
y.device(d) = x * (temp1 + temp2);
}
};
@@ -226,8 +226,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- auto temp1 = (x < (threshold * -1)).template cast().eval();
- auto temp2 = (x > threshold).template cast().eval();
+ auto temp1 = (x < static_cast(threshold * -1)).template cast().eval();
+ auto temp2 = (x > static_cast(threshold)).template cast().eval();
dx.device(d) = dy * (temp1 + temp2).template cast();
}
};
@@ -243,9 +243,10 @@ struct SoftShrinkFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- auto temp1 = (x > lambda).template cast().eval();
- auto temp2 = (x < -lambda).template cast().eval();
- y.device(d) = temp1 * (x - lambda) + temp2 * (x + lambda);
+ auto lambdaT = static_cast(lambda);
+ auto temp1 = (x > lambdaT).template cast().eval();
+ auto temp2 = (x < -lambdaT).template cast().eval();
+ y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
}
};
@@ -257,8 +258,9 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- auto temp1 = (x > lambda).template cast().eval();
- auto temp2 = (x < -lambda).template cast().eval();
+ auto lambdaT = static_cast(lambda);
+ auto temp1 = (x > lambdaT).template cast().eval();
+ auto temp2 = (x < -lambdaT).template cast().eval();
dx.device(d) = dy * (temp1 + temp2).template cast();
}
};
@@ -362,7 +364,8 @@ struct BReluFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- y.device(d) = x.cwiseMax(t_min).cwiseMin(t_max);
+ y.device(d) =
+ x.cwiseMax(static_cast(t_min)).cwiseMin(static_cast(t_max));
}
};
@@ -375,7 +378,9 @@ struct BReluGradFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- dx.device(d) = dy * ((x > t_min) * (x < t_max)).template cast();
+ dx.device(d) = dy *
+ ((x > static_cast(t_min)) * (x < static_cast(t_max)))
+ .template cast();
}
};
@@ -390,7 +395,8 @@ struct Relu6Functor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- y.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(threshold);
+ y.device(d) =
+ x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold));
}
};
@@ -402,8 +408,9 @@ struct Relu6GradFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- dx.device(d) =
- dy * ((x > static_cast(0)) * (x < threshold)).template cast();
+ dx.device(d) = dy *
+ ((x > static_cast(0)) * (x < static_cast(threshold)))
+ .template cast();
}
};
@@ -463,7 +470,8 @@ struct SoftReluFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- auto temp = x.cwiseMax(-threshold).cwiseMin(threshold);
+ auto tmp = static_cast(threshold);
+ auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
y.device(d) = (static_cast(1) + temp.exp()).log();
}
};
@@ -476,7 +484,8 @@ struct SoftReluGradFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- auto temp = ((x > -threshold) * (x < threshold)).template cast().eval();
+ auto tmp = static_cast(threshold);
+ auto temp = ((x > -tmp) * (x < tmp)).template cast().eval();
dx.device(d) = dy * (static_cast(1) - (-y).exp()) * temp;
}
};
@@ -490,7 +499,7 @@ struct LeakyReluFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- y.device(d) = x.cwiseMax(alpha * x);
+ y.device(d) = x.cwiseMax(static_cast(alpha) * x);
}
};
@@ -502,7 +511,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor {
}
template
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
- auto temp1 = alpha * (x < static_cast(0)).template cast().eval();
+ auto temp1 = static_cast(alpha) *
+ (x < static_cast(0)).template cast().eval();
auto temp2 = (x >= static_cast(0)).template cast().eval();
dx.device(d) = dy * (temp1 + temp2).template cast();
}
@@ -517,9 +527,9 @@ struct ELUFunctor : public BaseActivationFunctor {
template
void operator()(Device d, X x, Y y) const {
- y.device(d) =
- x.cwiseMax(static_cast(0)) +
- (alpha * (x.exp() - static_cast(1))).cwiseMin(static_cast(0));
+ y.device(d) = x.cwiseMax(static_cast