diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23bb27e77b..db3c3b8e20 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -103,6 +103,11 @@ if(ANDROID OR IOS)
add_definitions(-DPADDLE_MOBILE_INFERENCE)
endif()
+if (APPLE OR WIN32)
+ set(WITH_MKL OFF CACHE STRING
+ "Disable MKL for building on mac and windows" FORCE)
+endif()
+
set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING
"A path setting third party libraries download & build directories.")
diff --git a/cmake/external/anakin.cmake b/cmake/external/anakin.cmake
index d205e39582..fb3d8ef8d5 100644
--- a/cmake/external/anakin.cmake
+++ b/cmake/external/anakin.cmake
@@ -7,7 +7,17 @@ set(ANAKIN_INSTALL_DIR "${THIRD_PARTY_PATH}/install/anakin" CACHE PATH
set(ANAKIN_INCLUDE "${ANAKIN_INSTALL_DIR}" CACHE STRING "root of Anakin header files")
set(ANAKIN_LIBRARY "${ANAKIN_INSTALL_DIR}" CACHE STRING "path of Anakin library")
-set(ANAKIN_COMPILE_EXTRA_FLAGS -Wno-error=unused-variable -Wno-error=format-extra-args -Wno-error=comment -Wno-error=format -Wno-error=switch -Wno-error=return-type -Wno-error=non-virtual-dtor -Wno-reorder -Wno-error=cpp)
+set(ANAKIN_COMPILE_EXTRA_FLAGS
+ -Wno-error=unused-variable -Wno-unused-variable
+ -Wno-error=format-extra-args -Wno-format-extra-args
+ -Wno-error=comment -Wno-comment
+ -Wno-error=format -Wno-format
+ -Wno-error=switch -Wno-switch
+ -Wno-error=return-type -Wno-return-type
+ -Wno-error=non-virtual-dtor -Wno-non-virtual-dtor
+ -Wno-sign-compare
+ -Wno-reorder
+ -Wno-error=cpp)
set(ANAKIN_LIBRARY_URL "https://github.com/pangge/Anakin/releases/download/3.0/anakin_release_simple.tar.gz")
diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake
index 85f40585da..82437a8424 100644
--- a/cmake/external/grpc.cmake
+++ b/cmake/external/grpc.cmake
@@ -50,6 +50,7 @@ ExternalProject_Add(
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_IN_SOURCE 1
+ PATCH_COMMAND git apply ${PADDLE_SOURCE_DIR}/patches/grpc/fix_too_early_destory.patch
# NOTE(yuyang18):
# Disable -Werror, otherwise the compile will fail in MacOS.
# It seems that we cannot configure that by make command.
diff --git a/cmake/version.cmake b/cmake/version.cmake
index cde650128a..79b8e8ac49 100644
--- a/cmake/version.cmake
+++ b/cmake/version.cmake
@@ -1,16 +1,21 @@
# Get the latest git tag.
set(PADDLE_VERSION $ENV{PADDLE_VERSION})
set(tmp_version "HEAD")
+set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
+set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+")
while ("${PADDLE_VERSION}" STREQUAL "")
execute_process(
- COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 ${tmp_version}
+ COMMAND ${GIT_EXECUTABLE} describe --tags --abbrev=0 --always ${tmp_version}
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
OUTPUT_VARIABLE GIT_TAG_NAME
RESULT_VARIABLE GIT_RESULT
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE)
if (NOT ${GIT_RESULT})
# Check the tag is a correct version
- if (${GIT_TAG_NAME} MATCHES "v[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?")
+ if (${GIT_TAG_NAME} MATCHES "${COMMIT_VERSION_REGEX}")
+ # if no tag was found, set PADDLE_VERSION to latest
+ set(PADDLE_VERSION "latest")
+ elseif (${GIT_TAG_NAME} MATCHES "v${TAG_VERSION_REGEX}")
string(REPLACE "v" "" PADDLE_VERSION ${GIT_TAG_NAME})
else() # otherwise, get the previous git tag name.
set(tmp_version "${GIT_TAG_NAME}~1")
diff --git a/doc/fluid/design/quantization/fixed_point_quantization.md b/doc/fluid/design/quantization/fixed_point_quantization.md
new file mode 100644
index 0000000000..085352fc56
--- /dev/null
+++ b/doc/fluid/design/quantization/fixed_point_quantization.md
@@ -0,0 +1,110 @@
+Fixed-point quantization uses lower bits, for example 2-bit, 3-bit or 8-bit fixed point, to represent weights and activations, which are usually stored as 32-bit single-precision floating-point values. The fixed-point representation reduces memory bandwidth, power consumption, computational resources, and model storage requirements. It is especially important for inference in embedded-device deployment.
+
+According to some experiments, directly quantizing a model trained in floating point works well for large models, such as VGG, which has many parameters, but the accuracy drops a lot for small models. To improve the tradeoff between accuracy and latency, many quantized training approaches have been proposed.
+
+This document designs a quantized training framework for Fluid. The first part introduces how to quantize, the second part describes the quantized training framework, and the last part illustrates how to calculate the quantization scale.
+
+
+### How to quantize
+
+There are many ways to quantize the float value to fixed-point value. For example:
+
+$$ r = min(max(x, a), b)$$
+$$ s = \frac{b - a}{n - 1} $$
+$$ q = \left \lfloor \frac{r - a}{s} \right \rceil $$
+
+where $x$ is the float value to be quantized, $[a, b]$ is the quantization range, $a$ is the minimum value and $b$ is the maximum value. $\left \lfloor \right \rceil$ denotes rounding to the nearest integer. If the quantization bit width is $k$, then $n = 2^k$; for example, $k = 8$ gives $n = 256$. $q$ is the quantized integer.
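+
+A minimal NumPy sketch of this *min-max* quantization (an illustration only, not Fluid code; the range $[a, b]$ and the bit width below are assumptions chosen for the example):
+
+```python
+import numpy as np
+
+def min_max_quantize(x, a, b, k=8):
+    """Quantize x to k-bit integers over the range [a, b]."""
+    n = 2 ** k                    # number of quantization levels
+    r = np.clip(x, a, b)          # r = min(max(x, a), b)
+    s = (b - a) / (n - 1)         # quantization step
+    q = np.rint((r - a) / s)      # round to the nearest integer
+    return q.astype(np.int64), s
+
+x = np.array([-1.2, -0.3, 0.0, 0.7, 2.5])
+q, s = min_max_quantize(x, a=-1.0, b=1.0)
+print(q)  # integers in [0, 255]
+```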
+
+
+The quantization we apply is parameterized by the number of quantization levels and the maximum absolute value:
+
+$$ M = max(abs(x)) $$
+$$ q = \left \lfloor \frac{x}{M} * (n - 1) \right \rceil $$
+
+where $x$ is the float value to be quantized and $M$ is the maximum absolute value. $\left \lfloor \right \rceil$ denotes rounding to the nearest integer. For 8-bit quantization, $n=2^{8}=256$. $q$ is the quantized integer.
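+
+A corresponding sketch of *max-abs* quantization and its dequantization (again only an illustration under the 8-bit assumption, not the actual Fluid operators):
+
+```python
+import numpy as np
+
+def max_abs_quantize(x, k=8):
+    """Quantize x with the max-abs scale: q = round(x / M * (n - 1))."""
+    n = 2 ** k
+    m = np.max(np.abs(x))             # M = max(abs(x))
+    q = np.rint(x / m * (n - 1))
+    return q.astype(np.int64), m
+
+def max_abs_dequantize(q, m, k=8):
+    n = 2 ** k
+    return q.astype(np.float32) / (n - 1) * m
+
+x = np.array([-1.2, -0.3, 0.0, 0.7, 2.5], dtype=np.float32)
+q, m = max_abs_quantize(x)
+x_dq = max_abs_dequantize(q, m)
+print(np.max(np.abs(x - x_dq)))       # small round-trip error
+```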
+
+
+Whether *min-max* quantization or *max-abs* quantization is used, both can be represented as:
+
+$q = scale * r + b$
+
+We call the *min-max* and *max-abs* values the quantization arguments; they are also called the quantization scale or quantization range.
+
+
+How to calculate the quantization scale (or maximum absolute value) for inference will be described in the last part.
+
+
+### Training Framework
+
+#### Forward pass
+
+The forward pass simulates quantization, see Figure 1.
+
+The training framework is shown in the following figure.
+
+
+
+Figure 1. Forward in training with simulated quantization.
+
+
+- First, both the input and the weight are quantized to 8-bit integers.
+- Second, do the multiplication (or convolution) operation with the integers.
+- Third, dequantize the multiplication (or convolution) results to 32-bit floating point.
+- Finally, do the bias addition in 32-bit floating point. Here, the bias is not quantized.
+
+For general matrix multiplication (GEMM), quantize $X$ and $W$:
+
+$$ X_q = \left \lfloor \frac{X}{X_m} * (n - 1) \right \rceil $$
+$$ W_q = \left \lfloor \frac{W}{W_m} * (n - 1) \right \rceil $$
+
+Do GEMM:
+
+$$ Y = X_q * W_q $$
+
+
+Dequantize $Y$:
+
+$$
+\begin{align}
+Y_{dq} &=\frac{Y}{(n - 1) * (n - 1)} * X_m * W_m \\\
+ &=\frac{X_q * W_q}{(n - 1) * (n - 1)} * X_m * W_m \\\
+ &=(\frac{X_q}{n - 1} * X_m) * (\frac{W_q}{n - 1} * W_m)
+\end{align}
+$$
+
+From these formulas, the dequantization can also be moved before the GEMM: first dequantize $X_q$ and $W_q$, then do the GEMM. The forward workflow in training is equivalent to the following framework.
+
+
+
+Figure 2. Equivalent forward in training with simulated quantization.
+
+
+We use this equivalent workflow in training. In our design, a quantization transpiler inserts the quantization and de-quantization operators into the Fluid `ProgramDesc`. Since the outputs of the quantization and de-quantization operators are still in floating point, they are called fake quantization and de-quantization operators, and the training framework is called simulated quantization.
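+
+A minimal NumPy sketch of this equivalence, using the max-abs quantization above (an illustration, not the actual transpiler; the shapes and 8-bit width are arbitrary assumptions):
+
+```python
+import numpy as np
+
+n = 2 ** 8
+rng = np.random.RandomState(0)
+X = rng.randn(4, 3)
+W = rng.randn(3, 2)
+
+X_m, W_m = np.max(np.abs(X)), np.max(np.abs(W))
+X_q = np.rint(X / X_m * (n - 1))
+W_q = np.rint(W / W_m * (n - 1))
+
+# Dequantize after the integer GEMM ...
+Y_after = (X_q @ W_q) / ((n - 1) * (n - 1)) * X_m * W_m
+# ... or dequantize X_q and W_q first, then do the GEMM.
+Y_before = (X_q / (n - 1) * X_m) @ (W_q / (n - 1) * W_m)
+
+print(np.allclose(Y_after, Y_before))  # True
+```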
+
+#### Backward pass
+
+See Figure 3. The gradients are calculated from the dequantized weights and activations. All inputs and outputs are 32-bit floating point. In the weight-updating process, the gradients are added to the original weights, not to the quantized or dequantized weights.
+
+
+
+Figure 3. Backward and weight updating in training with simulated quantization.
+
+
+So the quantization transpiler will change some inputs of the corresponding backward operators.
+
+### How to calculate quantization scale
+
+There are two strategies to calculate the quantization scale; we call them the dynamic and static strategies. The dynamic strategy calculates the quantization scale in each iteration. The static strategy keeps a fixed quantization scale for different inputs.
+
+For weights, we apply the dynamic strategy in training, that is to say, the quantization scale is recalculated in each iteration until the training is finished.
+
+For activations, the quantization scales are estimated during training and then used in inference. There are several ways to estimate them:
+
+
+1. Calculate the mean of the maximum absolute values over a window.
+2. Calculate the max of the maximum absolute values over a window.
+3. Calculate the running mean of the maximum absolute value over a window, as follows (see the sketch after this list):
+
+    $$ V_t = (1 - k) * V + k * V_{t-1} $$
+
+    where $V$ is the maximum absolute value of the current batch, $V_t$ is the running mean value, and $k$ is a decay factor, such as 0.9.
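+
+A small sketch of the third (running-mean) estimator, assuming a decay factor of 0.9 and random data as a stand-in for activations (illustrative only, not the Fluid implementation):
+
+```python
+import numpy as np
+
+def update_running_scale(batch, running_scale, k=0.9):
+    """V_t = (1 - k) * V + k * V_{t-1}, where V is the batch max-abs value."""
+    v = np.max(np.abs(batch))
+    return (1 - k) * v + k * running_scale
+
+running_scale = 0.0
+for _ in range(100):
+    activations = np.random.randn(32, 64)  # stand-in for a batch of activations
+    running_scale = update_running_scale(activations, running_scale)
+print(running_scale)
+```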
diff --git a/doc/fluid/design/quantization/quantization_backward_and_optimization.png b/doc/fluid/design/quantization/quantization_backward_and_optimization.png
new file mode 100644
index 0000000000..84f8235ab8
Binary files /dev/null and b/doc/fluid/design/quantization/quantization_backward_and_optimization.png differ
diff --git a/doc/fluid/design/quantization/quantization_equivalent_forward.png b/doc/fluid/design/quantization/quantization_equivalent_forward.png
new file mode 100644
index 0000000000..df49c86453
Binary files /dev/null and b/doc/fluid/design/quantization/quantization_equivalent_forward.png differ
diff --git a/doc/fluid/design/quantization/quantization_forward.png b/doc/fluid/design/quantization/quantization_forward.png
new file mode 100644
index 0000000000..0913f61621
Binary files /dev/null and b/doc/fluid/design/quantization/quantization_forward.png differ
diff --git a/doc/v2/howto/capi/workflow_of_capi_cn.md b/doc/v2/howto/capi/workflow_of_capi_cn.md
index 3acdbae28e..db1568a2af 100644
--- a/doc/v2/howto/capi/workflow_of_capi_cn.md
+++ b/doc/v2/howto/capi/workflow_of_capi_cn.md
@@ -28,9 +28,9 @@
### 准备预测模型
-准备预测模型部分,我们以手写数字识别任务为例进行介绍。手写数字识别任务定义了一个含有[两个隐层的简单全连接网络](https://github.com/PaddlePaddle/book/blob/develop/02.recognize_digits/README.cn.md#softmax回归softmax-regression),网络接受一幅图片作为输入,将图片分类到 0 ~ 9 类别标签之一。完整代码可以查看[此目录](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense) 中的相关脚本。
+准备预测模型部分,我们以手写数字识别任务为例进行介绍。手写数字识别任务定义了一个含有[两个隐层的简单全连接网络](https://github.com/PaddlePaddle/book/blob/develop/02.recognize_digits/README.cn.md#softmax回归softmax-regression),网络接受一幅图片作为输入,将图片分类到 0 ~ 9 类别标签之一。完整代码可以查看[此目录](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense) 中的相关脚本。
-调用C-API开发预测程序需要一个训练好的模型,运行[MNIST手写数字识别目录](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense)下的[mnist_v2.py](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/examples/model_inference/dense/mnist_v2.py)脚本,在终端执行`python mnist_v2.py`,会使用 PaddlePaddle 内置的 [MNIST 数据集](http://yann.lecun.com/exdb/mnist/)进行训练。训练好的模型默认保存在当前运行目录下的`models`目录中。
+调用C-API开发预测程序需要一个训练好的模型,运行[MNIST手写数字识别目录](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense)下的[mnist_v2.py](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/examples/model_inference/dense/mnist_v2.py)脚本,在终端执行`python mnist_v2.py`,会使用 PaddlePaddle 内置的 [MNIST 数据集](http://yann.lecun.com/exdb/mnist/)进行训练。训练好的模型默认保存在当前运行目录下的`models`目录中。
下面,我们将训练结束后存储下来的模型转换成预测模型。
@@ -48,7 +48,7 @@
dump_v2_config(predict, "trainer_config.bin", True)
```
- 对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense)这个示例,[`mnist_v2.py`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense/mnist_v2.py)脚本集成了序列化神经网络结构的过程,可以直接运行 `python mnist_v2.py --task dump_config` 对神经网络结构进行序列化,结果会写入当前运行目录下的`trainer_config.bin`文件中。
+ 对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense)这个示例,[`mnist_v2.py`](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense/mnist_v2.py)脚本集成了序列化神经网络结构的过程,可以直接运行 `python mnist_v2.py --task dump_config` 对神经网络结构进行序列化,结果会写入当前运行目录下的`trainer_config.bin`文件中。
使用这种方式,需要**在运行时将神经网络的多个可学习参数放在同一个目录中**,C-API可以通过分别指定序列化后的网络结构文件和参数目录来加载训练好的模型。
@@ -68,7 +68,7 @@
merge_v2_model(net, param_file, output_file)
```
- 对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense)这个示例,可直接运行 `python` [merge_v2_model.py](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference/dense/merge_v2_model.py)。序列化结果会写入当前运行目录下的`output.paddle.model`文件中。使用这种方式,运行时C-API可以通过指定`output.paddle.model`文件的路径来加载预测模型。
+ 对[手写数字识别](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense)这个示例,可直接运行 `python` [merge_v2_model.py](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference/dense/merge_v2_model.py)。序列化结果会写入当前运行目录下的`output.paddle.model`文件中。使用这种方式,运行时C-API可以通过指定`output.paddle.model`文件的路径来加载预测模型。
#### 注意事项
1. 为使用C-API,在调用`dump_v2_config`序列化神经网络结构时,参数`binary`必须指定为`True`。
@@ -77,10 +77,10 @@
### 编写预测代码
-预测代码更多详细示例代码请参考[C-API使用示例](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/capi/examples/model_inference) 目录下的代码示例。这一节对图1中预测代码编写的5个步骤进行介绍和说明。
+预测代码更多详细示例代码请参考[C-API使用示例](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/capi/examples/model_inference) 目录下的代码示例。这一节对图1中预测代码编写的5个步骤进行介绍和说明。
#### step 1. 初始化PaddlePaddle运行环境
-第一步需调用[`paddle_init`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/main.h#L27) 初始化PaddlePaddle运行环境,该接口接受两个参数:参数的个数和参数列表。
+第一步需调用[`paddle_init`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/main.h#L27) 初始化PaddlePaddle运行环境,该接口接受两个参数:参数的个数和参数列表。
#### step2. 加载模型
@@ -88,8 +88,8 @@
概念上,在 PaddlePaddle 内部,一个GradientMachine类的对象管理着一组计算层(PaddlePaddle Layers)来完成前向和反向计算,并处理与之相关的所有细节。在调用C-API预测时,只需进行前向计算而无需调用反向计算。这篇文档之后部分会使用`gradient machine`来特指调用PaddlePaddle C-API创建的GradientMachine类的对象。每一个 `gradient machine` 都会管理维护一份训练好的模型,下面是C-API提供的,两种常用的模型加载方式:
-1. 调用[`paddle_gradient_machine_load_parameter_from_disk`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/gradient_machine.h#L61)接口,从磁盘加载预测模型。这时`gradient machine`会独立拥有一份训练好的模型;
-1. 调用[`paddle_gradient_machine_create_shared_param`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/gradient_machine.h#L88)接口,与其它`gradient machine`的共享已经加载的预测模型。这种情况多出现在使用多线程预测时,通过多个线程共享同一个模型来减少内存开销。可参考[此示例](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/examples/model_inference/multi_thread/main.c)。
+1. 调用[`paddle_gradient_machine_load_parameter_from_disk`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/gradient_machine.h#L61)接口,从磁盘加载预测模型。这时`gradient machine`会独立拥有一份训练好的模型;
+1. 调用[`paddle_gradient_machine_create_shared_param`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/gradient_machine.h#L88)接口,与其它`gradient machine`的共享已经加载的预测模型。这种情况多出现在使用多线程预测时,通过多个线程共享同一个模型来减少内存开销。可参考[此示例](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/examples/model_inference/multi_thread/main.c)。
- 注意事项
@@ -117,7 +117,7 @@ C-API支持的所有输入数据类型和他们的组织方式,请参考“输
#### step 4. 前向计算
-完成上述准备之后,通过调用 [`paddle_gradient_machine_forward`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/capi/gradient_machine.h#L73) 接口完成神经网络的前向计算。
+完成上述准备之后,通过调用 [`paddle_gradient_machine_forward`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/legacy/capi/gradient_machine.h#L73) 接口完成神经网络的前向计算。
#### step 5. 清理
diff --git a/paddle/contrib/inference/CMakeLists.txt b/paddle/contrib/inference/CMakeLists.txt
index c30eff5010..87173fc42a 100644
--- a/paddle/contrib/inference/CMakeLists.txt
+++ b/paddle/contrib/inference/CMakeLists.txt
@@ -45,14 +45,31 @@ endfunction(inference_api_test)
cc_library(paddle_inference_api
SRCS paddle_inference_api.cc paddle_inference_api_impl.cc
DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
+if(NOT APPLE)
+ set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_inference_api.sym")
+ set_target_properties(paddle_inference_api PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+endif()
# Here the shared library doesn't depend on other fluid libraries, or double free will occur.
cc_library(paddle_inference_api_shared SHARED
SRCS paddle_inference_api.cc paddle_inference_api_impl.cc)
+add_dependencies(paddle_inference_api_shared ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
set_target_properties(paddle_inference_api_shared PROPERTIES OUTPUT_NAME paddle_inference_api)
+
if(NOT APPLE)
- set(LINK_FLAGS "-fPIC -fvisibility=hidden")
+ set(LINK_FLAGS "-Wl,--version-script ${CMAKE_CURRENT_SOURCE_DIR}/paddle_inference_api.map")
set_target_properties(paddle_inference_api_shared PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+ FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/check_symbol.cmake
+ "execute_process(COMMAND bash -c \"${CMAKE_CURRENT_SOURCE_DIR}/check_symbol.sh"
+ " ${CMAKE_CURRENT_BINARY_DIR}/libpaddle_inference_api.so\" RESULT_VARIABLE symbol_res)\n"
+ "if(NOT \"\${symbol_res}\" STREQUAL \"0\")\n"
+ " message(FATAL_ERROR \"Check symbol failed.\")\n"
+ "endif()\n")
+ add_custom_command(
+ OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol"
+ COMMAND ${CMAKE_COMMAND} -P "${CMAKE_CURRENT_BINARY_DIR}/check_symbol.cmake"
+ DEPENDS paddle_inference_api_shared)
+ add_custom_target(check_symbol ALL DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/.check_symbol")
endif()
cc_test(test_paddle_inference_api
diff --git a/paddle/contrib/inference/check_symbol.sh b/paddle/contrib/inference/check_symbol.sh
new file mode 100755
index 0000000000..6547ca1413
--- /dev/null
+++ b/paddle/contrib/inference/check_symbol.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+lib=$1
+if [ $# -ne 1 ]; then echo "No input library"; exit -1 ; fi
+
+num_paddle_syms=$(nm -D --defined-only ${lib} | grep paddle | wc -l)
+num_google_syms=$(nm -D --defined-only ${lib} | grep google | wc -l)
+
+if [ $num_paddle_syms -le 0 ]; then echo "Have no paddle symbols"; exit -1 ; fi
+if [ $num_google_syms -ge 1 ]; then echo "Have some google symbols"; exit -1 ; fi
+
+exit 0
diff --git a/paddle/contrib/inference/demo/CMakeLists.txt b/paddle/contrib/inference/demo/CMakeLists.txt
index ecece6fe34..2d501bf008 100644
--- a/paddle/contrib/inference/demo/CMakeLists.txt
+++ b/paddle/contrib/inference/demo/CMakeLists.txt
@@ -13,8 +13,6 @@
# limitations under the License.
#
-inference_api_test(simple_on_word2vec ARGS test_word2vec)
-
option(WITH_INFERENCE_DEMO "Compile with Inference demo" OFF)
if(NOT WITH_INFERENCE_DEMO)
return()
diff --git a/paddle/contrib/inference/demo_ci/CMakeLists.txt b/paddle/contrib/inference/demo_ci/CMakeLists.txt
new file mode 100644
index 0000000000..789bff7f23
--- /dev/null
+++ b/paddle/contrib/inference/demo_ci/CMakeLists.txt
@@ -0,0 +1,77 @@
+cmake_minimum_required(VERSION 3.0)
+
+project(cpp_inference_demo CXX C)
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
+
+if(NOT DEFINED PADDLE_LIB)
+ message(FATAL_ERROR "please set PADDLE_LIB with -DPADDLE_LIB=/path/paddle/lib")
+endif()
+if(NOT DEFINED DEMO_NAME)
+ message(FATAL_ERROR "please set DEMO_NAME with -DDEMO_NAME=demo_name")
+endif()
+
+option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON)
+option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF)
+option(WITH_STATIC_LIB "Compile demo with static/shared library, default use static." ON)
+
+if(WITH_GPU)
+ set(CUDA_LIB "/usr/local/cuda/lib64/" CACHE STRING "CUDA Library")
+endif()
+
+include_directories("${PADDLE_LIB}")
+include_directories("${PADDLE_LIB}/third_party/install/protobuf/include")
+include_directories("${PADDLE_LIB}/third_party/install/glog/include")
+include_directories("${PADDLE_LIB}/third_party/install/gflags/include")
+include_directories("${PADDLE_LIB}/third_party/install/snappy/include")
+include_directories("${PADDLE_LIB}/third_party/install/snappystream/include")
+include_directories("${PADDLE_LIB}/third_party/install/zlib/include")
+
+include_directories("${PADDLE_LIB}/third_party/boost")
+include_directories("${PADDLE_LIB}/third_party/eigen3")
+
+link_directories("${PADDLE_LIB}/third_party/install/snappy/lib")
+link_directories("${PADDLE_LIB}/third_party/install/snappystream/lib")
+link_directories("${PADDLE_LIB}/third_party/install/protobuf/lib")
+link_directories("${PADDLE_LIB}/third_party/install/glog/lib")
+link_directories("${PADDLE_LIB}/third_party/install/gflags/lib")
+link_directories("${PADDLE_LIB}/third_party/install/zlib/lib")
+
+add_executable(${DEMO_NAME} ${DEMO_NAME}.cc)
+
+if(WITH_MKL)
+ include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
+ set(MATH_LIB ${PADDLE_LIB}/third_party/install/mklml/lib/libmklml_intel.so
+ ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5.so)
+ set(MKLDNN_PATH "${PADDLE_LIB}/third_party/install/mkldnn")
+ if(EXISTS ${MKLDNN_PATH})
+ include_directories("${MKLDNN_PATH}/include")
+ set(MKLDNN_LIB ${MKLDNN_PATH}/lib/libmkldnn.so.0)
+ endif()
+else()
+ set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas.a)
+endif()
+
+if(WITH_STATIC_LIB)
+ set(DEPS
+ "-Wl,--whole-archive"
+ ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.a
+ "-Wl,--no-whole-archive"
+ ${PADDLE_LIB}/contrib/inference/libpaddle_inference_api.a)
+else()
+ # Note: libpaddle_inference_api.so must put before libpaddle_fluid.so
+ set(DEPS
+ ${PADDLE_LIB}/contrib/inference/libpaddle_inference_api.so
+ ${PADDLE_LIB}/paddle/fluid/inference/libpaddle_fluid.so)
+endif()
+set(EXTERNAL_LIB "-lrt -ldl -lpthread")
+
+set(DEPS ${DEPS}
+ ${MATH_LIB} ${MKLDNN_LIB}
+ glog gflags protobuf snappystream snappy z
+ ${EXTERNAL_LIB})
+if(WITH_GPU)
+ set(DEPS ${DEPS} ${CUDA_LIB}/libcudart.so)
+endif()
+
+target_link_libraries(${DEMO_NAME} ${DEPS})
diff --git a/paddle/contrib/inference/demo_ci/run.sh b/paddle/contrib/inference/demo_ci/run.sh
new file mode 100755
index 0000000000..e3a7269af7
--- /dev/null
+++ b/paddle/contrib/inference/demo_ci/run.sh
@@ -0,0 +1,34 @@
+set -x
+PADDLE_ROOT=$1
+WITH_MKL=$2
+WITH_GPU=$3
+if [ $3 == "ON" ]; then
+ use_gpu_list='true false'
+else
+ use_gpu_list='false'
+fi
+
+mkdir -p build
+cd build
+
+for WITH_STATIC_LIB in false; do
+ rm -rf *
+ cmake .. -DPADDLE_LIB=${PADDLE_ROOT}/build/fluid_install_dir/ \
+ -DWITH_MKL=$WITH_MKL \
+ -DDEMO_NAME=simple_on_word2vec \
+ -DWITH_GPU=$WITH_GPU \
+ -DWITH_STATIC_LIB=$WITH_STATIC_LIB
+ make
+ for use_gpu in $use_gpu_list; do
+ ./simple_on_word2vec \
+ --dirname=${PADDLE_ROOT}/build/python/paddle/fluid/tests/book/word2vec.inference.model \
+ --use_gpu=$use_gpu
+ done
+done
+if [ $? -eq 0 ]; then
+ exit 0
+else
+ echo "inference demo runs fail."
+ exit 1
+fi
+set +x
diff --git a/paddle/contrib/inference/demo/simple_on_word2vec.cc b/paddle/contrib/inference/demo_ci/simple_on_word2vec.cc
similarity index 68%
rename from paddle/contrib/inference/demo/simple_on_word2vec.cc
rename to paddle/contrib/inference/demo_ci/simple_on_word2vec.cc
index c253014642..9713837f86 100644
--- a/paddle/contrib/inference/demo/simple_on_word2vec.cc
+++ b/paddle/contrib/inference/demo_ci/simple_on_word2vec.cc
@@ -16,21 +16,27 @@ limitations under the License. */
* This file contains a simple demo for how to take a model for inference.
*/
+#include
#include
-#include
#include
#include
-#include "paddle/contrib/inference/paddle_inference_api.h"
+#include "contrib/inference/paddle_inference_api.h"
+#include "paddle/fluid/platform/enforce.h"
+
+DEFINE_string(dirname, "", "Directory of the inference model.");
+DEFINE_bool(use_gpu, false, "Whether use gpu.");
namespace paddle {
namespace demo {
-DEFINE_string(dirname, "", "Directory of the inference model.");
-
void Main(bool use_gpu) {
//# 1. Create PaddlePredictor with a config.
NativeConfig config;
- config.model_dir = FLAGS_dirname + "word2vec.inference.model";
+ if (FLAGS_dirname.empty()) {
+ LOG(INFO) << "Usage: ./simple_on_word2vec --dirname=path/to/your/model";
+ exit(1);
+ }
+ config.model_dir = FLAGS_dirname;
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
@@ -54,12 +60,16 @@ void Main(bool use_gpu) {
CHECK(predictor->Run(slots, &outputs));
//# 4. Get output.
- ASSERT_EQ(outputs.size(), 1UL);
- LOG(INFO) << "output buffer size: " << outputs.front().data.length();
+ PADDLE_ENFORCE(outputs.size(), 1UL);
+ // Check the output buffer size and result of each tid.
+ PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
+ float result[5] = {
+ 0.00129761, 0.00151112, 0.000423564, 0.00108815, 0.000932706};
const size_t num_elements = outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
- LOG(INFO) << static_cast(outputs.front().data.data())[i];
+ PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i],
+ result[i]);
}
}
}
@@ -68,7 +78,7 @@ void MainThreads(int num_threads, bool use_gpu) {
// Multi-threads only support on CPU
// 0. Create PaddlePredictor with a config.
NativeConfig config;
- config.model_dir = FLAGS_dirname + "word2vec.inference.model";
+ config.model_dir = FLAGS_dirname;
config.use_gpu = use_gpu;
config.fraction_of_gpu_memory = 0.15;
config.device = 0;
@@ -94,14 +104,17 @@ void MainThreads(int num_threads, bool use_gpu) {
CHECK(predictor->Run(inputs, &outputs));
// 4. Get output.
- ASSERT_EQ(outputs.size(), 1UL);
- LOG(INFO) << "TID: " << tid << ", "
- << "output buffer size: " << outputs.front().data.length();
+ PADDLE_ENFORCE(outputs.size(), 1UL);
+ // Check the output buffer size and result of each tid.
+ PADDLE_ENFORCE(outputs.front().data.length(), 33168UL);
+ float result[5] = {
+ 0.00129761, 0.00151112, 0.000423564, 0.00108815, 0.000932706};
const size_t num_elements =
outputs.front().data.length() / sizeof(float);
// The outputs' buffers are in CPU memory.
for (size_t i = 0; i < std::min(5UL, num_elements); i++) {
- LOG(INFO) << static_cast(outputs.front().data.data())[i];
+ PADDLE_ENFORCE(static_cast(outputs.front().data.data())[i],
+ result[i]);
}
}
});
@@ -111,15 +124,18 @@ void MainThreads(int num_threads, bool use_gpu) {
}
}
-TEST(demo, word2vec_cpu) { Main(false /*use_gpu*/); }
-TEST(demo_multi_threads, word2vec_cpu_1) { MainThreads(1, false /*use_gpu*/); }
-TEST(demo_multi_threads, word2vec_cpu_4) { MainThreads(4, false /*use_gpu*/); }
-
-#ifdef PADDLE_WITH_CUDA
-TEST(demo, word2vec_gpu) { Main(true /*use_gpu*/); }
-TEST(demo_multi_threads, word2vec_gpu_1) { MainThreads(1, true /*use_gpu*/); }
-TEST(demo_multi_threads, word2vec_gpu_4) { MainThreads(4, true /*use_gpu*/); }
-#endif
-
} // namespace demo
} // namespace paddle
+
+int main(int argc, char** argv) {
+ google::ParseCommandLineFlags(&argc, &argv, true);
+ paddle::demo::Main(false /* use_gpu*/);
+ paddle::demo::MainThreads(1, false /* use_gpu*/);
+ paddle::demo::MainThreads(4, false /* use_gpu*/);
+ if (FLAGS_use_gpu) {
+ paddle::demo::Main(true /*use_gpu*/);
+ paddle::demo::MainThreads(1, true /*use_gpu*/);
+ paddle::demo::MainThreads(4, true /*use_gpu*/);
+ }
+ return 0;
+}
diff --git a/paddle/contrib/inference/paddle_inference_api.map b/paddle/contrib/inference/paddle_inference_api.map
new file mode 100644
index 0000000000..5203784dc1
--- /dev/null
+++ b/paddle/contrib/inference/paddle_inference_api.map
@@ -0,0 +1,6 @@
+{
+ global:
+ *paddle*;
+ local:
+ *;
+};
diff --git a/paddle/contrib/inference/paddle_inference_api.sym b/paddle/contrib/inference/paddle_inference_api.sym
new file mode 100644
index 0000000000..ef2a04d788
--- /dev/null
+++ b/paddle/contrib/inference/paddle_inference_api.sym
@@ -0,0 +1 @@
+*paddle*
diff --git a/paddle/contrib/inference/test_paddle_inference_api_impl.cc b/paddle/contrib/inference/test_paddle_inference_api_impl.cc
index 88c4e665a3..c3649dcb96 100644
--- a/paddle/contrib/inference/test_paddle_inference_api_impl.cc
+++ b/paddle/contrib/inference/test_paddle_inference_api_impl.cc
@@ -249,7 +249,7 @@ void MainThreadsImageClassification(bool use_gpu) {
const size_t len = local_outputs[0].data.length();
float* data = static_cast(local_outputs[0].data.data());
float* ref_data = refs[tid].data();
- EXPECT_EQ(refs[tid].numel(), len / sizeof(float));
+ EXPECT_EQ((size_t)refs[tid].numel(), len / sizeof(float));
for (int i = 0; i < refs[tid].numel(); ++i) {
EXPECT_NEAR(ref_data[i], data[i], 1e-3);
}
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 397c9f7394..ec252929d5 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -27,6 +27,7 @@ cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
+cc_test(reader_test SRCS reader_test.cc DEPS reader)
cc_test(variable_test SRCS variable_test.cc)
diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
index eb4e7ec52f..1d80bab90f 100644
--- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
+#include
#include
#include
#include "paddle/fluid/framework/executor.h"
@@ -53,8 +54,14 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
}
}
}
+ std::vector fetch_data;
+ std::exception_ptr eptr;
+ try {
+ fetch_data = underlying_executor_->Run(fetch_tensors);
+ } catch (...) {
+ eptr = std::current_exception();
+ }
- auto fetch_data = underlying_executor_->Run(fetch_tensors);
drop_scope_counter_ += 1;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
@@ -69,7 +76,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
scope->DeleteScope(local_scope);
}
}
- return fetch_data;
+ if (eptr) {
+ std::rethrow_exception(eptr);
+ } else {
+ return fetch_data;
+ }
}
} // namespace details
} // namespace framework
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index 99b10254a7..07097c7e75 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -78,6 +78,10 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
set.clear();
};
+ // Clean run context
+ run_op_futures_.clear();
+ exception_.reset();
+
// Step 3. Execution
while (!pending_vars.empty()) {
// 1. Run All Ready ops
@@ -96,16 +100,19 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto cur_ready_vars = ready_vars.PopAll(1, &timeout);
if (timeout) {
- std::lock_guard l(exception_mu_);
+ std::unique_lock l(exception_mu_);
if (exception_) {
+ l.unlock();
+ for (auto &run_op_future : run_op_futures_) {
+ run_op_future.wait();
+ }
+ l.lock();
std::exception *exp = exception_.get();
if (dynamic_cast(exp)) {
auto e = *static_cast(exp);
- exception_.reset();
throw e;
} else if (dynamic_cast(exp)) {
auto e = *static_cast(exp);
- exception_.reset();
throw e;
} else {
LOG(FATAL) << "Unknown exception.";
@@ -222,7 +229,7 @@ void ThreadedSSAGraphExecutor::RunOp(
}
};
if (pool_) {
- pool_->enqueue(op_run);
+ run_op_futures_.emplace_back(pool_->enqueue(op_run));
} else {
op_run();
}
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index c69e0487e2..09973b7a72 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -15,6 +15,7 @@
#pragma once
#include
+#include
#include
#include
#include
@@ -77,6 +78,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
private:
ExecutionStrategy strategy_;
+ // use std::list because clear(), push_back, and for_each are O(1)
+ std::list> run_op_futures_;
};
} // namespace details
diff --git a/paddle/fluid/framework/op_info.cc b/paddle/fluid/framework/op_info.cc
index f1261dee03..af75baa5c4 100644
--- a/paddle/fluid/framework/op_info.cc
+++ b/paddle/fluid/framework/op_info.cc
@@ -21,8 +21,8 @@ namespace framework {
// a static local variable is already being initialized.
// https://stackoverflow.com/questions/11711920/how-to-implement-multithread-safe-singleton-in-c11-without-using-mutex
OpInfoMap& OpInfoMap::Instance() {
- static OpInfoMap* g_op_info_map = new OpInfoMap();
- return *g_op_info_map;
+ static OpInfoMap g_op_info_map;
+ return g_op_info_map;
}
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/reader.cc b/paddle/fluid/framework/reader.cc
index 0b36f1116d..5897d320a8 100644
--- a/paddle/fluid/framework/reader.cc
+++ b/paddle/fluid/framework/reader.cc
@@ -13,29 +13,61 @@
// limitations under the License.
#include "paddle/fluid/framework/reader.h"
+#include
namespace paddle {
namespace framework {
-ReaderBase::~ReaderBase() {}
-FileReader::FileReader(const std::vector &dims) : dims_(dims) {}
-
-void FileReader::ReadNext(std::vector *out) {
+void ReaderBase::ReadNext(std::vector *out) {
+ std::lock_guard lock(mu_);
+ PADDLE_ENFORCE_EQ(status_, ReaderStatus::kRunning);
ReadNextImpl(out);
- if (out->empty()) {
- return;
- }
+}
- PADDLE_ENFORCE_EQ(out->size(), dims_.size());
- for (size_t i = 0; i < dims_.size(); ++i) {
- auto &actual = (*out)[i].dims();
- auto &expect = dims_[i];
+void ReaderBase::InsertDecoratedReader(
+ const std::shared_ptr &decorated_reader) {
+ std::lock_guard guard(mu_);
+ decorated_readers_.emplace_back(decorated_reader);
+}
- PADDLE_ENFORCE_EQ(actual.size(), expect.size());
- for (int j = 0; j < actual.size(); ++j) {
- // PADDLE_ENFORCE(actual[i] == expect[i] || expect[i] == -1);
+std::unordered_set ReaderBase::GetEndPoints() {
+ std::unordered_set result;
+ std::deque queue;
+ queue.emplace_back(this);
+ while (!queue.empty()) { // BFS search
+ auto *front = queue.front();
+ queue.pop_front();
+ if (front->decorated_readers_.empty()) {
+ result.emplace(front);
+ } else {
+ for (auto &reader : front->decorated_readers_) {
+ if (auto *reader_ptr = reader.lock().get()) {
+ queue.emplace_back(reader_ptr);
+ }
+ }
}
}
+
+ return result;
}
+
+void ReaderBase::Shutdown() {
+ std::lock_guard lock(mu_);
+ if (status_ != ReaderStatus::kStopped) {
+ ShutdownImpl();
+ status_ = ReaderStatus::kStopped;
+ }
+}
+
+void ReaderBase::Start() {
+ std::lock_guard lock(mu_);
+ if (status_ != ReaderStatus::kRunning) {
+ StartImpl();
+ status_ = ReaderStatus::kRunning;
+ }
+}
+
+ReaderBase::~ReaderBase() { Shutdown(); }
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h
index 64d4ceab62..6c4432cb7a 100644
--- a/paddle/fluid/framework/reader.h
+++ b/paddle/fluid/framework/reader.h
@@ -15,6 +15,7 @@
#pragma once
#include
+#include
#include
#include "paddle/fluid/framework/ddim.h"
@@ -24,61 +25,116 @@
namespace paddle {
namespace framework {
+enum ReaderStatus { kRunning, kStopped };
+
class ReaderBase {
public:
- virtual void ReadNext(std::vector* out) = 0;
+ void ReadNext(std::vector* out);
+
+ void Shutdown();
- virtual void ReInit() = 0;
+ void Start();
+
+  // Return the readers at the end of the decorating chain. Basically,
+  // they are the readers just before the read op.
+ std::unordered_set GetEndPoints();
virtual ~ReaderBase();
+
+ protected:
+ virtual void ReadNextImpl(std::vector* out) = 0;
+
+ virtual void ShutdownImpl() {}
+
+ virtual void StartImpl() {}
+
+ ReaderStatus status_{kRunning};
+
+ mutable std::mutex mu_;
+
+ private:
+ friend class DecoratedReader;
+  // These methods can only be invoked inside DecoratedReader to record the
+  // decorating chain.
+ void InsertDecoratedReader(
+ const std::shared_ptr& decorated_reader);
+  // The set of readers that decorate this reader.
+ std::vector> decorated_readers_;
};
-class DecoratedReader : public ReaderBase {
+class DecoratedReader : public ReaderBase,
+ public std::enable_shared_from_this {
public:
explicit DecoratedReader(const std::shared_ptr& reader)
: ReaderBase(), reader_(reader) {
PADDLE_ENFORCE_NOT_NULL(reader_);
}
- void ReInit() override { reader_->ReInit(); }
+ void RegisterDecorateChain() {
+ reader_->InsertDecoratedReader(shared_from_this());
+ }
protected:
- std::shared_ptr reader_;
-};
-
-class FileReader : public ReaderBase {
- public:
- explicit FileReader(const std::vector& dims);
-
- void ReadNext(std::vector* out) override;
+ void ShutdownImpl() override { reader_->Shutdown(); }
- protected:
- virtual void ReadNextImpl(std::vector* out) = 0;
+ void StartImpl() override { reader_->Start(); }
- private:
- std::vector dims_;
+ std::shared_ptr reader_;
};
+// FileReader is just a conceptual class.
+class FileReader : public ReaderBase {};
+
// The ReaderHolder is used as reader' unified wrapper,
// making it easier to access different type reader in Variables.
class ReaderHolder {
public:
- void Reset(ReaderBase* reader) { reader_.reset(reader); }
+ template
+ void Reset(const std::shared_ptr& reader) {
+ auto reader_base = std::dynamic_pointer_cast(reader);
+ PADDLE_ENFORCE_NOT_NULL(reader_base);
+ reader_ = reader_base;
+ }
- std::shared_ptr Get() const { return reader_; }
+ const std::shared_ptr& Get() const { return reader_; }
void ReadNext(std::vector* out) {
PADDLE_ENFORCE_NOT_NULL(reader_);
reader_->ReadNext(out);
}
- void ReInit() {
+
+ void ResetAll() {
+ auto end_readers = reader_->GetEndPoints();
+ for (auto* reader : end_readers) {
+ reader->Shutdown();
+ }
+ for (auto* reader : end_readers) {
+ reader->Start();
+ }
+ }
+
+ void Shutdown() {
+ PADDLE_ENFORCE_NOT_NULL(reader_);
+ reader_->Shutdown();
+ }
+
+ void Start() {
PADDLE_ENFORCE_NOT_NULL(reader_);
- reader_->ReInit();
+ reader_->Start();
}
+ operator const std::shared_ptr&() const { return this->reader_; }
+
private:
std::shared_ptr reader_;
};
+template
+inline std::shared_ptr MakeDecoratedReader(ARGS&&... args) {
+ std::shared_ptr reader(new T(std::forward(args)...));
+ reader->RegisterDecorateChain();
+ return reader;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/reader_test.cc b/paddle/fluid/framework/reader_test.cc
new file mode 100644
index 0000000000..f0d07cb7c1
--- /dev/null
+++ b/paddle/fluid/framework/reader_test.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/reader.h"
+#include
+#include "gtest/gtest.h"
+
+class StubDecoratedReader : public paddle::framework::DecoratedReader {
+ public:
+ explicit StubDecoratedReader(const std::shared_ptr &reader)
+ : DecoratedReader(reader) {}
+
+ void ReadNextImpl(std::vector *out) override {}
+};
+
+class StubRootReader : public paddle::framework::ReaderBase {
+ public:
+ void ReadNextImpl(std::vector *out) override {}
+};
+
+TEST(READER, decorate_chain) {
+ auto root = std::make_shared();
+ auto end_point1 =
+ paddle::framework::MakeDecoratedReader(root);
+ auto end_point2 =
+ paddle::framework::MakeDecoratedReader(root);
+
+ {
+ auto endpoints = root->GetEndPoints();
+ ASSERT_EQ(endpoints.size(), 2U);
+ ASSERT_NE(endpoints.count(end_point1.get()), 0);
+ ASSERT_NE(endpoints.count(end_point2.get()), 0);
+ }
+
+ {
+ auto end_point3 =
+ paddle::framework::MakeDecoratedReader(root);
+ ASSERT_EQ(root->GetEndPoints().size(), 3U);
+ }
+ { ASSERT_EQ(root->GetEndPoints().size(), 2U); }
+}
diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt
index 1895aea7f9..b1c33c3415 100644
--- a/paddle/fluid/inference/CMakeLists.txt
+++ b/paddle/fluid/inference/CMakeLists.txt
@@ -13,6 +13,12 @@ endif()
# Create static library
cc_library(paddle_fluid DEPS ${fluid_modules} paddle_fluid_api)
+if(NOT APPLE)
+  # TODO(liuyiqu): Temporarily disable the link flag because it is not supported on Mac.
+ set(LINK_FLAGS "-Wl,--retain-symbols-file ${CMAKE_CURRENT_SOURCE_DIR}/paddle_fluid.sym")
+ set_target_properties(paddle_fluid PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
+endif()
+
# Create shared library
cc_library(paddle_fluid_shared SHARED
SRCS io.cc
diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc
index d09bf3ed16..bd24e8a7d9 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph.cc
@@ -90,6 +90,20 @@ std::string DataFlowGraph::DotString() const {
return dot.Build();
}
+std::string DataFlowGraph::HumanReadableInfo(bool show_values,
+ bool show_functions) const {
+ std::stringstream values, functions;
+ for (auto &n : nodes.nodes()) {
+ if (show_values && n->IsValue()) {
+ values << n->repr() << "\n";
+ }
+ if (show_functions && n->IsFunction()) {
+ functions << n->repr() << "\n";
+ }
+ }
+ return "Values:\n" + values.str() + "\n\n" + "Functions:\n" + functions.str();
+}
+
//
// NodesBFSIterator
//
@@ -146,7 +160,7 @@ bool GraphTraits::NodesBFSIterator::operator==(
if ((!queue_.empty()) && (!other.queue_.empty())) {
return queue_.front() == other.queue_.front() &&
visited_.size() == other.visited_.size(); // here need to check the
- // equality of queue and
+ // equality of queue and
// visited. Just a light but week implementation.
}
return false;
@@ -208,6 +222,76 @@ Node *GraphTraits::NodesDFSIterator::operator->() {
return stack_.top();
}
+GraphTraits::NodesTSIterator::NodesTSIterator(
+ const std::vector &source) {
+ PADDLE_ENFORCE(!source.empty(),
+ "Start points of topological sorting should not be empty!");
+ std::unordered_set visited;
+ std::unordered_set to_visit{source.begin(), source.end()};
+
+ std::vector inlink_visited;
+ while (!to_visit.empty()) {
+ std::vector queue(to_visit.begin(), to_visit.end());
+ for (auto *p : queue) {
+ inlink_visited.clear();
+
+ std::copy_if(p->inlinks.begin(), p->inlinks.end(),
+ std::back_inserter(inlink_visited),
+ [&](Node *x) { return visited.count(x); });
+
+ if (inlink_visited.size() == p->inlinks.size()) {
+ sorted_.push_back(p);
+ for (auto *_ : p->outlinks) {
+ if (!visited.count(_)) {
+ to_visit.insert(_);
+ }
+ }
+
+ to_visit.erase(p);
+ visited.insert(p);
+ }
+ }
+ }
+}
+
+GraphTraits::NodesTSIterator::NodesTSIterator(
+ const paddle::inference::analysis::GraphTraits<
+ DataFlowGraph>::NodesTSIterator &other)
+ : sorted_(other.sorted_), cursor_(other.cursor_) {}
+
+Node &GraphTraits::NodesTSIterator::operator*() {
+ PADDLE_ENFORCE_LT(cursor_, sorted_.size());
+ return *sorted_[cursor_];
+}
+
+paddle::inference::analysis::GraphTraits::NodesTSIterator
+ &GraphTraits::NodesTSIterator::operator++() {
+ if (++cursor_ >= sorted_.size()) {
+ sorted_.clear();
+ cursor_ = 0;
+ }
+ return *this;
+}
+paddle::inference::analysis::GraphTraits::NodesTSIterator &
+GraphTraits::NodesTSIterator::operator=(
+ const paddle::inference::analysis::GraphTraits<
+ DataFlowGraph>::NodesTSIterator &other) {
+ cursor_ = other.cursor_;
+ sorted_ = other.sorted_;
+ return *this;
+}
+
+bool GraphTraits::NodesTSIterator::operator==(
+ const paddle::inference::analysis::GraphTraits<
+ DataFlowGraph>::NodesTSIterator &other) {
+ return sorted_ == other.sorted_ && cursor_ == other.cursor_;
+}
+
+Node *GraphTraits::NodesTSIterator::operator->() {
+ PADDLE_ENFORCE_LT(cursor_, sorted_.size());
+ return sorted_[cursor_];
+}
+
} // namespace analysis
} // namespace inference
} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h
index a4fefc83e0..5dd914d197 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph.h
+++ b/paddle/fluid/inference/analysis/data_flow_graph.h
@@ -48,6 +48,9 @@ struct DataFlowGraph {
// Output a DOT graph file for debug.
std::string DotString() const;
+ std::string HumanReadableInfo(bool show_values = true,
+ bool show_functions = true) const;
+
private:
// Remove duplicate edges and so on.
void Clean();
@@ -107,6 +110,32 @@ struct GraphTraits {
std::unordered_set visited_;
};
+ // Topological sorting iterator on nodes.
+ struct NodesTSIterator
+ : public std::iterator {
+ NodesTSIterator() = default;
+ explicit NodesTSIterator(const std::vector &source);
+ NodesTSIterator(NodesTSIterator &&other)
+ : sorted_(std::move(other.sorted_)), cursor_(other.cursor_) {
+ other.cursor_ = 0;
+ }
+ NodesTSIterator(const NodesTSIterator &other);
+
+ Node &operator*();
+ NodesTSIterator &operator++();
+ // TODO(Superjomn) current implementation just compare the first
+ // element, need to compare the graph and all the elements in the queue and
+ // set.
+ NodesTSIterator &operator=(const NodesTSIterator &other);
+ bool operator==(const NodesTSIterator &other);
+ bool operator!=(const NodesTSIterator &other) { return !(*this == other); }
+ Node *operator->();
+
+ private:
+ std::vector sorted_;
+ int cursor_{0};
+ };
+
explicit GraphTraits(DataFlowGraph *graph) : graph_(graph) {}
// default use BFS to visit the nodes.
@@ -119,17 +148,24 @@ struct GraphTraits {
iterator_range nodes_in_DFS() {
return iterator_range(nodes_dfs_begin(), nodes_dfs_end());
}
+ iterator_range nodes_in_TS() {
+ return iterator_range(nodes_ts_begin(), nodes_ts_end());
+ }
private:
NodesBFSIterator nodes_bfs_begin() {
return NodesBFSIterator(graph_->inputs);
}
NodesBFSIterator nodes_bfs_end() { return NodesBFSIterator(); }
+
NodesDFSIterator nodes_dfs_begin() {
return NodesDFSIterator(graph_->inputs);
}
NodesDFSIterator nodes_dfs_end() { return NodesDFSIterator(); }
+ NodesTSIterator nodes_ts_begin() { return NodesTSIterator(graph_->inputs); }
+ NodesTSIterator nodes_ts_end() { return NodesTSIterator(); }
+
private:
DataFlowGraph *graph_;
};
diff --git a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
index 9d7cceeb65..7912f8d7f1 100644
--- a/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
+++ b/paddle/fluid/inference/analysis/data_flow_graph_tester.cc
@@ -24,11 +24,11 @@ TEST(DataFlowGraph, BFS) {
auto dfg = ProgramDescToDFG(desc);
dfg.Build();
- for (auto* in : dfg.inputs) {
+ for (auto *in : dfg.inputs) {
LOG(INFO) << "inputs: " << in->name() << " "
<< static_cast(in->type());
}
- for (auto* out : dfg.outputs) {
+ for (auto *out : dfg.outputs) {
LOG(INFO) << "outputs: " << out->name() << " "
<< static_cast(out->type());
}
@@ -57,6 +57,71 @@ TEST(DataFlowGraph, DFS) {
ASSERT_EQ(count, dfg.nodes.size());
}
+// Topological sorting.
+/*
+ * Graph topology
+ * inputs: 0, 1, 2
+ * 0 -> 4
+ * 0 -> 5
+ * 1 -> 6
+ * 2 -> 7
+ * 4 -> 5
+ * 4 -> 7
+ * 4 -> 3
+ * 7 -> 3
+ */
+TEST(DataFlowGraph, TS) {
+ DataFlowGraph graph;
+
+ for (int i = 0; i < 8; i++) {
+ auto *node = graph.nodes.Create(Node::Type::kValue);
+ node->SetName("node-" + std::to_string(i));
+ }
+
+ auto add_link = [&](int i, int j) {
+ Node *source = graph.nodes.GetMutable(i);
+ Node *target = graph.nodes.GetMutable(j);
+ target->inlinks.push_back(source);
+ source->outlinks.push_back(target);
+ };
+
+ graph.inputs.push_back(graph.nodes.GetMutable(0));
+ graph.inputs.push_back(graph.nodes.GetMutable(1));
+ graph.inputs.push_back(graph.nodes.GetMutable(2));
+
+ add_link(0, 4);
+ add_link(0, 5);
+ add_link(1, 6);
+ add_link(2, 7);
+ add_link(4, 5);
+ add_link(4, 7);
+ add_link(4, 3);
+ add_link(7, 3);
+
+ auto its = GraphTraits(&graph).nodes_in_TS();
+ std::vector sorted_ids;
+ for (auto it = its.begin(); it != its.end(); ++it) {
+ LOG(INFO) << it->name();
+ sorted_ids.push_back(it->id());
+ }
+
+ // Assert a occurs prior to b in the sorted_ids.
+ auto assert_positive_sequence_pair = [&](int a, int b) {
+ auto a_offset = std::find(sorted_ids.begin(), sorted_ids.end(), a);
+ auto b_offset = std::find(sorted_ids.begin(), sorted_ids.end(), b);
+ ASSERT_LT(a_offset, b_offset);
+ };
+
+ assert_positive_sequence_pair(2, 7);
+ assert_positive_sequence_pair(7, 3);
+ assert_positive_sequence_pair(4, 3);
+ assert_positive_sequence_pair(0, 4);
+ assert_positive_sequence_pair(0, 5);
+ assert_positive_sequence_pair(1, 6);
+ assert_positive_sequence_pair(4, 5);
+ assert_positive_sequence_pair(4, 7);
+}
+
} // namespace analysis
} // namespace inference
} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
index cfbbc284e4..cbca5abdd5 100644
--- a/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
+++ b/paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass_tester.cc
@@ -27,7 +27,7 @@ TEST_F(DFG_Tester, Init) {
DataFlowGraph graph;
pass.Run(&graph);
// Analysis is sensitive to ProgramDesc, careful to change the original model.
- ASSERT_EQ(graph.nodes.size(), 37);
+ ASSERT_EQ(graph.nodes.size(), 37UL);
pass.Finalize();
LOG(INFO) << '\n' << graph.DotString();
}
diff --git a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
index 8134494f8b..67dd4da54b 100644
--- a/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
+++ b/paddle/fluid/inference/analysis/subgraph_splitter_tester.cc
@@ -82,7 +82,7 @@ TEST_F(DFG_Tester, Fuse) {
// At least one nodes should be deleted.
ASSERT_EQ(dfg.nodes.size(), count0 + 1); // added a new FunctionBlock
- ASSERT_EQ(6UL, count1);
+ ASSERT_EQ(6, count1);
}
} // namespace analysis
diff --git a/paddle/fluid/inference/paddle_fluid.sym b/paddle/fluid/inference/paddle_fluid.sym
new file mode 100644
index 0000000000..ef2a04d788
--- /dev/null
+++ b/paddle/fluid/inference/paddle_fluid.sym
@@ -0,0 +1 @@
+*paddle*
diff --git a/paddle/fluid/memory/detail/buddy_allocator.cc b/paddle/fluid/memory/detail/buddy_allocator.cc
index 4194ba1979..01a8501dd4 100644
--- a/paddle/fluid/memory/detail/buddy_allocator.cc
+++ b/paddle/fluid/memory/detail/buddy_allocator.cc
@@ -19,8 +19,9 @@ namespace paddle {
namespace memory {
namespace detail {
-BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator,
- size_t min_chunk_size, size_t max_chunk_size)
+BuddyAllocator::BuddyAllocator(
+ std::unique_ptr system_allocator, size_t min_chunk_size,
+ size_t max_chunk_size)
: min_chunk_size_(min_chunk_size),
max_chunk_size_(max_chunk_size),
cache_(system_allocator->UseGpu()),
diff --git a/paddle/fluid/memory/detail/buddy_allocator.h b/paddle/fluid/memory/detail/buddy_allocator.h
index 2f39d774d6..f0c83efc23 100644
--- a/paddle/fluid/memory/detail/buddy_allocator.h
+++ b/paddle/fluid/memory/detail/buddy_allocator.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include // NOLINT
#include
#include
@@ -32,8 +33,8 @@ namespace detail {
class BuddyAllocator {
public:
- BuddyAllocator(SystemAllocator* system_allocator, size_t min_chunk_size,
- size_t max_chunk_size);
+ BuddyAllocator(std::unique_ptr system_allocator,
+ size_t min_chunk_size, size_t max_chunk_size);
~BuddyAllocator();
@@ -103,7 +104,7 @@ class BuddyAllocator {
private:
/*! Allocate CPU/GPU memory from system */
- SystemAllocator* system_allocator_;
+ std::unique_ptr system_allocator_;
std::mutex mutex_;
};
diff --git a/paddle/fluid/memory/malloc.cc b/paddle/fluid/memory/malloc.cc
index bd98ed8189..7c800b3c16 100644
--- a/paddle/fluid/memory/malloc.cc
+++ b/paddle/fluid/memory/malloc.cc
@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include
+
#include "paddle/fluid/memory/malloc.h"
#include "glog/logging.h"
@@ -34,12 +36,15 @@ namespace memory {
using BuddyAllocator = detail::BuddyAllocator;
BuddyAllocator* GetCPUBuddyAllocator() {
+ static std::once_flag init_flag;
static detail::BuddyAllocator* a = nullptr;
- if (a == nullptr) {
- a = new detail::BuddyAllocator(new detail::CPUAllocator,
- platform::CpuMinChunkSize(),
- platform::CpuMaxChunkSize());
- }
+
+ std::call_once(init_flag, []() {
+ a = new detail::BuddyAllocator(
+ std::unique_ptr(new detail::CPUAllocator),
+ platform::CpuMinChunkSize(), platform::CpuMaxChunkSize());
+ });
+
return a;
}
@@ -68,27 +73,33 @@ size_t Used(platform::CPUPlace place) {
#ifdef PADDLE_WITH_CUDA
BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
- static BuddyAllocator** as = NULL;
- if (as == NULL) {
+ static std::once_flag init_flag;
+ static detail::BuddyAllocator** a_arr = nullptr;
+
+ std::call_once(init_flag, [gpu_id]() {
int gpu_num = platform::GetCUDADeviceCount();
- as = new BuddyAllocator*[gpu_num];
- for (int gpu = 0; gpu < gpu_num; gpu++) {
- as[gpu] = nullptr;
+ PADDLE_ENFORCE(gpu_id < gpu_num, "gpu_id:%d should < gpu_num:%d", gpu_id,
+ gpu_num);
+
+ a_arr = new BuddyAllocator*[gpu_num];
+ for (int i = 0; i < gpu_num; i++) {
+ a_arr[i] = nullptr;
+ platform::SetDeviceId(i);
+ a_arr[i] = new BuddyAllocator(
+ std::unique_ptr(new detail::GPUAllocator(i)),
+ platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());
+
+ VLOG(10) << "\n\nNOTE: each GPU device use "
+ << FLAGS_fraction_of_gpu_memory_to_use * 100
+ << "% of GPU memory.\n"
+ << "You can set GFlags environment variable '"
+ << "FLAGS_fraction_of_gpu_memory_to_use"
+ << "' to change the fraction of GPU usage.\n\n";
}
- }
+ });
+
platform::SetDeviceId(gpu_id);
- if (!as[gpu_id]) {
- as[gpu_id] = new BuddyAllocator(new detail::GPUAllocator(gpu_id),
- platform::GpuMinChunkSize(),
- platform::GpuMaxChunkSize());
- VLOG(10) << "\n\nNOTE: each GPU device use "
- << FLAGS_fraction_of_gpu_memory_to_use * 100
- << "% of GPU memory.\n"
- << "You can set GFlags environment variable '"
- << "FLAGS_fraction_of_gpu_memory_to_use"
- << "' to change the fraction of GPU usage.\n\n";
- }
- return as[gpu_id];
+ return a_arr[gpu_id];
}
template <>
@@ -125,12 +136,16 @@ void Free(platform::CUDAPlace place, void* p) {
}
BuddyAllocator* GetCUDAPinnedBuddyAllocator() {
- static BuddyAllocator* ba = NULL;
- if (ba == NULL) {
- ba = new BuddyAllocator(new detail::CUDAPinnedAllocator,
+ static std::once_flag init_flag;
+ static BuddyAllocator* ba = nullptr;
+
+ std::call_once(init_flag, []() {
+ ba = new BuddyAllocator(std::unique_ptr(
+ new detail::CUDAPinnedAllocator),
platform::CUDAPinnedMinChunkSize(),
platform::CUDAPinnedMaxChunkSize());
- }
+ });
+
return ba;
}
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index ab1d214333..bc07bbe67e 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -265,6 +265,8 @@ op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
+op_library(unsqueeze_op DEPS reshape_op)
+op_library(squeeze_op DEPS reshape_op)
if (WITH_GPU)
op_library(conv_op DEPS vol2col depthwise_conv im2col)
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index 693bf973c2..5912a1a17c 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -216,6 +216,18 @@ class BatchNormKernel
saved_mean_e.setZero();
saved_variance_e.setZero();
+ EigenVectorArrayMap running_mean_arr(
+ mean_out->mutable_data(ctx.GetPlace()), C);
+ EigenVectorArrayMap running_var_arr(
+ variance_out->mutable_data(ctx.GetPlace()), C);
+
+ if ((N * sample_size) == 1) {
+ LOG(WARNING) << "Only 1 element in normalization dimension, "
+ << "we skip the batch norm calculation, let y = x.";
+ framework::TensorCopySync(*x, ctx.GetPlace(), y);
+ return;
+ }
+
switch (data_layout) {
case DataLayout::kNCHW: {
ConstEigenArrayMap x_arr(x->data(), sample_size, N * C);
@@ -247,10 +259,6 @@ class BatchNormKernel
PADDLE_THROW("Unknown storage order: %s", data_layout_str);
}
- EigenVectorArrayMap running_mean_arr(
- mean_out->mutable_data(ctx.GetPlace()), C);
- EigenVectorArrayMap running_var_arr(
- variance_out->mutable_data(ctx.GetPlace()), C);
running_mean_arr =
running_mean_arr * momentum + saved_mean_e * (1. - momentum);
running_var_arr =
@@ -427,6 +435,11 @@ class BatchNormGradKernel
d_bias_arr.setZero();
d_scale_arr.setZero();
+ if ((N * sample_size) == 1) {
+ framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
+ return;
+ }
+
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size);
switch (data_layout) {
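The (N * sample_size) == 1 guard added above exists because a single element per normalization dimension yields zero batch variance, so normalizing would only divide the input by sqrt(epsilon); the kernel falls back to copying x into y instead. A standalone illustration of the same reasoning, not the operator code itself:

    #include <cmath>
    #include <vector>

    // Normalize one channel; with a single element the variance is 0, so we
    // mirror the kernel's fallback and return the input unchanged (y = x).
    std::vector<float> NormalizeChannel(const std::vector<float>& x, float eps) {
      if (x.size() == 1) return x;
      float mean = 0.f, var = 0.f;
      for (float v : x) mean += v;
      mean /= x.size();
      for (float v : x) var += (v - mean) * (v - mean);
      var /= x.size();
      std::vector<float> y(x.size());
      for (size_t i = 0; i < x.size(); ++i)
        y[i] = (x[i] - mean) / std::sqrt(var + eps);
      return y;
    }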
diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc
index 550dd32d36..ca6cd86693 100644
--- a/paddle/fluid/operators/batch_norm_op.cu.cc
+++ b/paddle/fluid/operators/batch_norm_op.cu.cc
@@ -72,6 +72,9 @@ class BatchNormKernel
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+ auto *y = ctx.Output("Y");
+ y->mutable_data(ctx.GetPlace());
+
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t data_desc_;
cudnnTensorDescriptor_t bn_param_desc_;
@@ -93,7 +96,7 @@ class BatchNormKernel
mode_ = CUDNN_BATCHNORM_SPATIAL;
#endif
- VLOG(1) << "Setting descriptors.";
+ VLOG(3) << "Setting descriptors.";
std::vector<int> dims;
std::vector<int> strides;
if (data_layout == DataLayout::kNCHW) {
@@ -113,11 +116,6 @@ class BatchNormKernel
const auto *scale = ctx.Input("Scale");
const auto *bias = ctx.Input("Bias");
- auto *y = ctx.Output("Y");
-
- // alloc memory
- y->mutable_data(ctx.GetPlace());
-
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle();
@@ -162,22 +160,28 @@ class BatchNormKernel
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
- double this_factor = 1. - momentum;
-
- CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
- handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
- data_desc_, x->template data<T>(), data_desc_,
- y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
- scale->template data<BatchNormParamType<T>>(),
- bias->template data<BatchNormParamType<T>>(), this_factor,
- mean_out->template mutable_data<BatchNormParamType<T>>(
- ctx.GetPlace()),
- variance_out->template mutable_data<BatchNormParamType<T>>(
- ctx.GetPlace()),
- epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
- ctx.GetPlace()),
- saved_variance->template mutable_data<BatchNormParamType<T>>(
- ctx.GetPlace())));
+ if ((N * H * W * D) == 1) {
+ LOG(WARNING) << "Only 1 element in normalization dimension, "
+ << "we skip the batch norm calculation, let y = x.";
+ framework::TensorCopySync(*x, ctx.GetPlace(), y);
+ } else {
+ double this_factor = 1. - momentum;
+
+ CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
+ handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
+ data_desc_, x->template data<T>(), data_desc_,
+ y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
+ scale->template data<BatchNormParamType<T>>(),
+ bias->template data<BatchNormParamType<T>>(), this_factor,
+ mean_out->template mutable_data<BatchNormParamType<T>>(
+ ctx.GetPlace()),
+ variance_out->template mutable_data<BatchNormParamType<T>>(
+ ctx.GetPlace()),
+ epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
+ ctx.GetPlace()),
+ saved_variance->template mutable_data<BatchNormParamType<T>>(
+ ctx.GetPlace())));
+ }
}
// clean when exit.
@@ -209,6 +213,25 @@ class BatchNormGradKernel
int N, C, H, W, D;
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+ // init output
+ auto *d_x = ctx.Output(framework::GradVarName("X"));
+ auto *d_scale = ctx.Output(framework::GradVarName("Scale"));
+ auto *d_bias = ctx.Output(framework::GradVarName("Bias"));
+
+ d_x->mutable_data(ctx.GetPlace());
+ d_scale->mutable_data(ctx.GetPlace());
+ d_bias->mutable_data(ctx.GetPlace());
+
+ auto &dev_ctx = ctx.template device_context();
+ if ((N * H * W * D) == 1) {
+ framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
+ math::SetConstant>
+ functor;
+ functor(dev_ctx, d_scale, static_cast>(0));
+ functor(dev_ctx, d_bias, static_cast>(0));
+ return;
+ }
+
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
@@ -247,21 +270,11 @@ class BatchNormGradKernel
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
- // init output
- auto *d_x = ctx.Output(framework::GradVarName("X"));
- auto *d_scale = ctx.Output(framework::GradVarName("Scale"));
- auto *d_bias = ctx.Output(framework::GradVarName("Bias"));
-
- d_x->mutable_data(ctx.GetPlace());
- d_scale->mutable_data(ctx.GetPlace());
- d_bias->mutable_data(ctx.GetPlace());
-
const auto *saved_mean = ctx.Input("SavedMean");
const auto *saved_var = ctx.Input("SavedVariance");
const void *saved_mean_data = saved_mean->template data();
const void *saved_var_data = saved_var->template data();
- auto &dev_ctx = ctx.template device_context();
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(),
CudnnDataType::kZero(), CudnnDataType::kOne(),
diff --git a/paddle/fluid/operators/conditional_block_op.cc b/paddle/fluid/operators/conditional_block_op.cc
index 8cc1d94260..580fde7538 100644
--- a/paddle/fluid/operators/conditional_block_op.cc
+++ b/paddle/fluid/operators/conditional_block_op.cc
@@ -205,9 +205,10 @@ class ConditionalBlockGradInferShape : public framework::InferShapeBase {
context->SetOutputsDim(framework::GradVarName("Params"),
context->GetInputsDim("Params"));
}
- PADDLE_ENFORCE(context->HasOutputs(framework::GradVarName("X")));
- context->SetOutputsDim(framework::GradVarName("X"),
- context->GetInputsDim("X"));
+ if (context->HasOutputs(framework::GradVarName("X"))) {
+ context->SetOutputsDim(framework::GradVarName("X"),
+ context->GetInputsDim("X"));
+ }
}
};
diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 6b06913d1c..5bfa1aaa69 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -29,6 +29,79 @@ using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
+class ConvMKLDNNHandler : public platform::MKLDNNHandler {
+ public:
+ ConvMKLDNNHandler(
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd,
+ const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+ const std::string& base_key)
+ : platform::MKLDNNHandler(dev_ctx, engine, base_key) {
+ conv_pd_ = conv_pd;
+ }
+
+ std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
+ return this->AcquireMemoryFromPrimitive(conv_pd_->dst_primitive_desc(), ptr,
+ "@dst_mem_p");
+ }
+
+ std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive(
+ const std::shared_ptr<mkldnn::memory> user_memory_p,
+ std::vector<mkldnn::primitive>& pipeline) {
+ auto src_pd = conv_pd_->src_primitive_desc();
+ auto user_pd = user_memory_p->get_primitive_desc();
+ return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p",
+ pipeline);
+ }
+
+ std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
+ const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
+ std::vector<mkldnn::primitive>& pipeline) {
+ auto user_weights_pd = user_weights_memory_p->get_primitive_desc();
+ auto weights_pd = conv_pd_->weights_primitive_desc();
+ return this->AcquireMemory(weights_pd, user_weights_pd,
+ user_weights_memory_p, "@weights_mem_p",
+ pipeline);
+ }
+
+ std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution(
+ std::shared_ptr<mkldnn::memory> src_memory_p,
+ std::shared_ptr<mkldnn::memory> weights_memory_p,
+ std::shared_ptr<mkldnn::memory> dst_memory_p) {
+ auto prim_key = key_ + "@conv_p";
+ auto prim_desc_key = key_ + "@conv_pd";
+ auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
+ dev_ctx_.GetBlob(prim_key));
+ PADDLE_ENFORCE((conv_p != nullptr) || (is_reusing_ == false),
+ "Fail to find convolution primitive in device context");
+ if (conv_p == nullptr) {
+ conv_p = std::make_shared<mkldnn::convolution_forward>(
+ *conv_pd_, *(src_memory_p), *(weights_memory_p.get()),
+ *(dst_memory_p.get()));
+
+ dev_ctx_.SetBlob(prim_key, conv_p);
+ } else {
+ is_reusing_ = true;
+ }
+ return conv_p;
+ }
+
+ // Generate keys for storing/retrieving primitives for this operator
+ // TODO(jczaja): Make hashing function more optimal
+ static std::string GetHash(memory::dims& input_dims,
+ memory::dims& weights_dims,
+ std::vector<int>& strides,
+ std::vector<int>& paddings,
+ std::vector<int>& dilations, int groups,
+ const std::string& suffix) {
+ return dims2str(input_dims) + dims2str(weights_dims) + dims2str(strides) +
+ dims2str(paddings) + dims2str(dilations) + std::to_string(groups) +
+ suffix;
+ }
+
+ private:
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd_;
+};
+
template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
@@ -36,10 +109,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
- // Get unique name for index
- const std::string key = ctx.op().Output("Output");
- const std::string key_conv_pd = key + "@conv_pd";
-
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -80,68 +149,62 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel {
paddle::framework::vectorize2int(filter->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
- // create mkldnn memory from input tensors (data/weights)
- auto user_src_memory = memory(
- {{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine},
- to_void_cast<T>(input_data));
- auto user_weights_memory =
- memory({{{weights_tz}, memory::data_type::f32, filter->format()},
- mkldnn_engine},
- to_void_cast<T>(filter_data));
+ // Get unique name for storing MKLDNN primitives
+ const std::string key = ConvMKLDNNHandler::GetHash(
+ src_tz, weights_tz, strides, paddings, dilations, groups,
+ ctx.op().Output("Output"));
+ const std::string key_conv_pd = key + "@conv_pd";
+
+ std::vector<mkldnn::primitive> pipeline;
+
+ auto user_src_md = platform::MKLDNNMemDesc(
+ {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
+ auto user_weights_md = platform::MKLDNNMemDesc(
+ {weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
- auto src_md = platform::MKLDNNMemDesc(src_tz, memory::data_type::f32,
- memory::format::any);
+ auto src_md = platform::MKLDNNMemDesc(
+ src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
auto weights_md = platform::MKLDNNMemDesc(
- weights_tz, memory::data_type::f32, memory::format::any);
- auto dst_md = platform::MKLDNNMemDesc(dst_tz, memory::data_type::f32,
- memory::format::any);
+ weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+ auto dst_md = platform::MKLDNNMemDesc(
+ dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, dst_md, strides, paddings, mkldnn_engine);
+ // Save conv_pd/src_memory/weights_memory for backward pass
+ dev_ctx.SetBlob(key_conv_pd, conv_pd);
- // create reorder primitive if the input format is not the preferred one
- auto src_memory = user_src_memory;
- primitive reorder_src;
- bool is_src_reordered = false;
- if (memory::primitive_desc(conv_pd->src_primitive_desc()) !=
- user_src_memory.get_primitive_desc()) {
- src_memory = memory(conv_pd->src_primitive_desc());
- reorder_src = reorder(user_src_memory, src_memory);
- is_src_reordered = true;
- }
- auto weights_memory = user_weights_memory;
- primitive reorder_weights;
- bool is_weights_reordered = false;
- if (memory::primitive_desc(conv_pd->weights_primitive_desc()) !=
- user_weights_memory.get_primitive_desc()) {
- weights_memory = memory(conv_pd->weights_primitive_desc());
- reorder_weights = reorder(user_weights_memory, weights_memory);
- is_weights_reordered = true;
- }
+ ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);
- // create memory primitive for conv dst
- auto dst_memory = memory(conv_pd->dst_primitive_desc(), output_data);
+ // create mkldnn memory from input tensors (data/weights)
+ auto user_src_memory_p =
+ handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
+ auto user_weights_memory_p = handler.AcquireWeightsMemory(
+ user_weights_md, to_void_cast<T>(filter_data));
+
+ // create reorder primitive if the input format is not the preferred one
+ auto src_memory_p =
+ handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
+ auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
+ user_weights_memory_p, pipeline);
+ auto dst_memory_p =
+ handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
// create convolution op primitive
- auto conv_prim = conv_fwd(*conv_pd, src_memory, weights_memory, dst_memory);
+ auto conv_p = handler.AcquireConvolution(src_memory_p, weights_memory_p,
+ dst_memory_p);
// push primitive to stream and wait until it's executed
- std::vector pipeline;
- if (is_src_reordered) pipeline.push_back(reorder_src);
- if (is_weights_reordered) pipeline.push_back(reorder_weights);
- pipeline.push_back(conv_prim);
+ pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait();
- // Save conv_pd/src_memory/weights_memory for backward pass
- dev_ctx.SetBlob(key_conv_pd, conv_pd);
-
output->set_layout(DataLayout::kMKLDNN);
- output->set_format(GetMKLDNNFormat(dst_memory));
+ output->set_format(GetMKLDNNFormat(*dst_memory_p));
}
private:
@@ -197,13 +260,10 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel {
if (!input_grad && !filter_grad) return;
- // Get an unique name from "argument" name of "Output" variable
- // This name will be used as key when saving info into device context
- const std::string key = ctx.op().Input("Output");
- const std::string key_conv_pd = key + "@conv_pd";
-
std::vector strides = ctx.Attr>("strides");
std::vector paddings = ctx.Attr>("paddings");
+ std::vector dilations = ctx.Attr>("dilations");
+ int groups = ctx.Attr("groups");
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
@@ -223,6 +283,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel {
paddle::framework::vectorize2int(filter->dims());
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+ // Get an unique name from "argument" name of "Output" variable
+ // This name will be used as key when saving info into device context
+ const std::string key =
+ ConvMKLDNNHandler::GetHash(src_tz, weights_tz, strides, paddings,
+ dilations, groups, ctx.op().Input("Output"));
+
+ const std::string key_conv_pd = key + "@conv_pd";
+
// create mkldnn memory from input tensors (input/weights/output_grad)
auto user_src_memory = memory(
{{{src_tz}, memory::data_type::f32, input->format()}, mkldnn_engine},
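ConvMKLDNNHandler::GetHash above builds the device-context key from every attribute that selects a distinct convolution primitive, so each configuration caches and reuses its own MKLDNN primitives across iterations. A rough, self-contained sketch of that key construction; dims2str and MakeConvKey here are stand-ins for the helpers the handler relies on, not Paddle's implementation:

    #include <sstream>
    #include <string>
    #include <vector>

    // Serialize a dimension vector as "d0-d1-...-".
    static std::string dims2str(const std::vector<int>& dims) {
      std::ostringstream os;
      for (int d : dims) os << d << "-";
      return os.str();
    }

    // Concatenate shapes, conv attributes, and an operator-specific suffix so
    // that different convolution configurations never share cached primitives.
    std::string MakeConvKey(const std::vector<int>& input_dims,
                            const std::vector<int>& weights_dims,
                            const std::vector<int>& strides,
                            const std::vector<int>& paddings,
                            const std::vector<int>& dilations, int groups,
                            const std::string& suffix) {
      return dims2str(input_dims) + dims2str(weights_dims) + dims2str(strides) +
             dims2str(paddings) + dims2str(dilations) + std::to_string(groups) +
             suffix;
    }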
diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc
index d5e095f9ca..a3bec3da45 100644
--- a/paddle/fluid/operators/cross_entropy_op.cc
+++ b/paddle/fluid/operators/cross_entropy_op.cc
@@ -124,8 +124,7 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
"Tensor with shape [N x D].");
AddOutput("Y",
"(Tensor, default Tensor), a 2-D tensor with shape "
- "[N x 1]. The cross entropy loss.")
- .Reuse("X");
+ "[N x 1]. The cross entropy loss.");
AddAttr("soft_label",
"(bool, default false), a flag indicating whether to "
"interpretate the given labels as soft labels.")
diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt
index 6d296ff7bf..a44d84cd7b 100644
--- a/paddle/fluid/operators/detection/CMakeLists.txt
+++ b/paddle/fluid/operators/detection/CMakeLists.txt
@@ -27,7 +27,8 @@ anchor_generator_op.cu)
detection_library(target_assign_op SRCS target_assign_op.cc
target_assign_op.cu)
detection_library(polygon_box_transform_op SRCS polygon_box_transform_op.cc
- polygon_box_transform_op.cu)
+polygon_box_transform_op.cu)
+detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
# Export local libraries to parent
set(DETECTION_LIBRARY ${LOCAL_DETECTION_LIBS} PARENT_SCOPE)
diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc
index 4e35c38e4e..b5cb6a724c 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cc
+++ b/paddle/fluid/operators/detection/prior_box_op.cc
@@ -149,6 +149,13 @@ class PriorBoxOpMaker : public framework::OpProtoAndCheckerMaker {
"(float) "
"Prior boxes center offset.")
.SetDefault(0.5);
+ AddAttr<bool>(
+ "min_max_aspect_ratios_order",
+ "(bool) If set to True, the output prior boxes are in the order of "
+ "[min, max, aspect_ratios], which is consistent with Caffe. "
+ "Please note that this order affects the weights order of the convolution "
+ "layer that follows, but does not affect the final detection results.")
+ .SetDefault(false);
AddComment(R"DOC(
Prior box operator
Generate prior boxes for the SSD (Single Shot MultiBox Detector) algorithm.
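The new min_max_aspect_ratios_order attribute only changes the order in which the prior boxes for each location are emitted, not their values. A small illustrative sketch of the two orderings (PriorBoxSizes is a made-up helper; aspect_ratios is assumed to be the expanded list whose first entry is 1.0, and widths/heights are half-sizes as in the CUDA kernel below):

    #include <cmath>
    #include <utility>
    #include <vector>

    std::vector<std::pair<float, float>> PriorBoxSizes(
        float min_size, float max_size, const std::vector<float>& aspect_ratios,
        bool min_max_aspect_ratios_order) {
      auto ar_box = [&](float ar) {
        return std::make_pair(min_size * std::sqrt(ar) / 2.f,
                              min_size / std::sqrt(ar) / 2.f);
      };
      float mm = std::sqrt(min_size * max_size) / 2.f;  // min/max interpolated box
      std::vector<std::pair<float, float>> boxes;
      if (!min_max_aspect_ratios_order) {
        // Default order: one box per aspect ratio, then the min/max box.
        for (float ar : aspect_ratios) boxes.push_back(ar_box(ar));
        boxes.emplace_back(mm, mm);
      } else {
        // Caffe-compatible order: min box, min/max box, remaining aspect ratios.
        boxes.emplace_back(min_size / 2.f, min_size / 2.f);
        boxes.emplace_back(mm, mm);
        for (size_t i = 1; i < aspect_ratios.size(); ++i)
          boxes.push_back(ar_box(aspect_ratios[i]));
      }
      return boxes;
    }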
diff --git a/paddle/fluid/operators/detection/prior_box_op.cu b/paddle/fluid/operators/detection/prior_box_op.cu
index f67e6ca91c..1ea8cfc1d2 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cu
+++ b/paddle/fluid/operators/detection/prior_box_op.cu
@@ -28,8 +28,8 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
const int im_width, const int as_num,
const T offset, const T step_width,
const T step_height, const T* min_sizes,
- const T* max_sizes, const int min_num,
- bool is_clip) {
+ const T* max_sizes, const int min_num, bool is_clip,
+ bool min_max_aspect_ratios_order) {
int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num;
int box_num = height * width * num_priors;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num;
@@ -44,14 +44,28 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height,
T min_size = min_sizes[m];
if (max_sizes) {
int s = p % (as_num + 1);
- if (s < as_num) {
- T ar = aspect_ratios[s];
- bw = min_size * sqrt(ar) / 2.;
- bh = min_size / sqrt(ar) / 2.;
+ if (!min_max_aspect_ratios_order) {
+ if (s < as_num) {
+ T ar = aspect_ratios[s];
+ bw = min_size * sqrt(ar) / 2.;
+ bh = min_size / sqrt(ar) / 2.;
+ } else {
+ T max_size = max_sizes[m];
+ bw = sqrt(min_size * max_size) / 2.;
+ bh = bw;
+ }
} else {
- T max_size = max_sizes[m];
- bw = sqrt(min_size * max_size) / 2.;
- bh = bw;
+ if (s == 0) {
+ bw = bh = min_size / 2.;
+ } else if (s == 1) {
+ T max_size = max_sizes[m];
+ bw = sqrt(min_size * max_size) / 2.;
+ bh = bw;
+ } else {
+ T ar = aspect_ratios[s - 1];
+ bw = min_size * sqrt(ar) / 2.;
+ bh = min_size / sqrt(ar) / 2.;
+ }
}
} else {
int s = p % as_num;
@@ -94,6 +108,8 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel {
auto variances = ctx.Attr>("variances");
auto flip = ctx.Attr("flip");
auto clip = ctx.Attr("clip");
+ auto min_max_aspect_ratios_order =
+ ctx.Attr("min_max_aspect_ratios_order");
std::vector<float> aspect_ratios;
ExpandAspectRatios(input_aspect_ratio, flip, &aspect_ratios);
@@ -149,7 +165,7 @@ class PriorBoxOpCUDAKernel : public framework::OpKernel {
GenPriorBox<<>>(
boxes->data