diff --git a/AUTHORS.md b/AUTHORS.md
index 2f756c09bc..8bacd8b169 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -30,6 +30,7 @@
| NHZlX | Zhao-Long Xing |
| Noplz | Yuan Gao |
| pakchoi | Chuan-Jiang Song |
+| panyx0718 | Xin Pan |
| pengli09 | Peng Li |
| pkuyym | Ya-Ming Yang |
| QiJune | Jun Qi |
diff --git a/benchmark/fluid/machine_translation.py b/benchmark/fluid/machine_translation.py
index d7a421c109..adde5f21ac 100644
--- a/benchmark/fluid/machine_translation.py
+++ b/benchmark/fluid/machine_translation.py
@@ -21,7 +21,7 @@ import argparse
import time
import distutils.util
-import paddle.v2 as paddle
+import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.framework as framework
diff --git a/benchmark/fluid/mnist.py b/benchmark/fluid/mnist.py
index dc10ac2ec1..1e2185dfac 100644
--- a/benchmark/fluid/mnist.py
+++ b/benchmark/fluid/mnist.py
@@ -20,7 +20,7 @@ import numpy as np
import argparse
import time
-import paddle.v2 as paddle
+import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
diff --git a/benchmark/fluid/resnet.py b/benchmark/fluid/resnet.py
index 1af5eaf6b4..831fa2c019 100644
--- a/benchmark/fluid/resnet.py
+++ b/benchmark/fluid/resnet.py
@@ -23,7 +23,7 @@ import time
import cProfile, pstats, StringIO
-import paddle.v2 as paddle
+import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.profiler as profiler
diff --git a/benchmark/fluid/stacked_dynamic_lstm.py b/benchmark/fluid/stacked_dynamic_lstm.py
index 5fcbdd64af..73bcc47b4d 100644
--- a/benchmark/fluid/stacked_dynamic_lstm.py
+++ b/benchmark/fluid/stacked_dynamic_lstm.py
@@ -23,10 +23,10 @@ import random
import time
import numpy
-import paddle.v2 as paddle
-import paddle.v2.dataset.imdb as imdb
+import paddle
+import paddle.dataset.imdb as imdb
import paddle.fluid as fluid
-from paddle.v2 import batch
+import paddle.batch as batch
import paddle.fluid.profiler as profiler
diff --git a/benchmark/fluid/vgg.py b/benchmark/fluid/vgg.py
index 9d990eff62..53e34e0cbd 100644
--- a/benchmark/fluid/vgg.py
+++ b/benchmark/fluid/vgg.py
@@ -17,7 +17,7 @@ from __future__ import print_function
import sys
import time
import numpy as np
-import paddle.v2 as paddle
+import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import argparse
diff --git a/doc/fluid/api/data/data_reader.rst b/doc/fluid/api/data/data_reader.rst
index d7c896a627..1a35d0bbc8 100644
--- a/doc/fluid/api/data/data_reader.rst
+++ b/doc/fluid/api/data/data_reader.rst
@@ -56,11 +56,11 @@ DataFeeder
Reader
======
-.. automodule:: paddle.v2.reader
+.. automodule:: paddle.reader
:members:
:noindex:
-.. automodule:: paddle.v2.reader.creator
+.. automodule:: paddle.reader.creator
:members:
:noindex:
diff --git a/doc/fluid/api/layers.rst b/doc/fluid/api/layers.rst
index 3790f09c84..ff3c9346a2 100644
--- a/doc/fluid/api/layers.rst
+++ b/doc/fluid/api/layers.rst
@@ -479,6 +479,13 @@ label_smooth
.. autofunction:: paddle.fluid.layers.label_smooth
:noindex:
+roi_pool
+---------
+
+.. autofunction:: paddle.fluid.layers.roi_pool
+ :noindex:
+
+
ops
===
@@ -820,3 +827,5 @@ topk
.. autofunction:: paddle.fluid.layers.topk
:noindex:
+
+
diff --git a/doc/fluid/design/data_type/float16.md b/doc/fluid/design/data_type/float16.md
index 1ea95ed6b5..844d2aafcf 100644
--- a/doc/fluid/design/data_type/float16.md
+++ b/doc/fluid/design/data_type/float16.md
@@ -3,7 +3,7 @@
## Why float16
Half precision (float16) is a binary floating-point format that occupies 16 bits in memory. float16 is half the size of traditional 32-bit single precision format (float) and has lower precision and smaller range.
-When high precision computation is not required, using float16 data type could potentially
+When high precision computation is not required (which is usually the case at least in the deep learning inference stage), using float16 data type could potentially
- reduce storage space, memory bandwidth, and power usages;
- increase the chance of data fitting into a smaller cache of lower latency;
@@ -12,7 +12,7 @@ When high precision computation is not required, using float16 data type could p
## Survey of current float16 support
A brief survey of float16 support on different compilers, hardwares, and libraries can be found below. Interested readers can refer to [link1](https://github.com/PaddlePaddle/Paddle/issues/4853) and [link2](https://github.com/Xreki/Xreki.github.io/blob/master/multi_data_types_in_dl_framework/ppt/float16_and_quantized_type.md) for more info.
-The goal of float16 is to serve as a key for the executor to find and run the correct version of compute method specialized for float16 in operator kernel. It should be compatible with various natively supported float16 implementations including `__half` for cuda, `float16_t` for ARM, and `Eigen::half` for Eigen to make writing customized float16 kernels easier.
+The goal of float16 is to serve as a key for the executor to find and run the correct version of compute method specialized for float16 in operator kernels. It should be compatible with various natively supported float16 implementations including `__half` for cuda, `float16_t` for ARM, and `Eigen::half` for Eigen to make writing customized float16 kernels easier.
### Compiler
- nvcc supports `__half` data type after CUDA 7.5.
@@ -95,11 +95,89 @@ float half_to_float(float16 h);
```
which provides one-to-one conversion between float32 and float16. These twos functions will do different conversion routines based on the current hardware. CUDA/ARM instrinsics will be used when the corresonding hardware is available. If the hardware or compiler level does not support float32 to float16 conversion, software emulation will be performed to do the conversion.
-## To do
-After float16 class is available, some of the future items are below:
+## float16 inference
+In Fluid, a neural network is represented as a protobuf message called [ProgramDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/program.md), whose Python wrapper is a [Program](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#program). The basic structure of a program is some nested [blocks](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#block), where each block consists of some [variable](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#variable) definitions and a sequence of [operators](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#operator). An [executor](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/executor.md) will run a given program desc by executing the sequence of operators in the entrance block of the program one by one.
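+
+To make this concrete, below is a minimal sketch (it assumes only the standard Fluid Python API) of building a tiny program and letting the executor run the operators in its entrance block:
+
+```Python
+import numpy
+import paddle.fluid as fluid
+
+# The layer calls below append variables and operators to
+# fluid.default_main_program(), whose ProgramDesc the executor will run.
+x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+y = fluid.layers.fc(input=x, size=1)
+
+place = fluid.CPUPlace()
+exe = fluid.Executor(place)
+exe.run(fluid.default_startup_program())  # initialize parameters
+
+# Execute the operators in the entrance block one by one.
+result, = exe.run(fluid.default_main_program(),
+                  feed={'x': numpy.random.rand(1, 13).astype(numpy.float32)},
+                  fetch_list=[y])
+```
+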
-- Update pybind/tensor_py.h to bind c++ float16 with numpy float16.
+### Operator level requirement
+Each operator has many kernels for different data types, devices, and library types. The operator will select the appropriate kernel to run based on, among other things, the data type of the input variables. By default, every Fluid operator has a float data type kernel that takes float variables as input and generates float output.
-- Modify `GetKernelType()` method in `framework/operator.h` to make it compatible with float16.
+This means that if we provide float input to the first operator in a program, then each operator will use the float kernel to compute a float output and send it as input to the next operator, which again triggers the float kernel. Overall, the program will run in float mode and give us a final output of float data type.
-- Create a type-casting operator that can convert the data type in tensor between float16 and other types.
+The same principle applies if we want a program to run in float16 mode. We provide an input variable of float16 data type to the first operator, and then, one by one, each operator in the program will run its float16 kernel (provided that each operator in this program has float16 kernels registered) until we finally obtain a float16 output variable.
+
+So the preliminary requirement for float16 inference is to add float16 kernels to the operators needed by a specific kind of program. For example, float16 inference on an image classification neural network like Vgg or Resnet typically requires the following operators to have float16 kernels: convolution, pooling, multiplication, addition, batch norm, dropout, relu, and softmax. Please refer to [new_op_en](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/dev/new_op_en.md) for details of how to add new kernels to an operator.
+
+### Variable level requirement
+Operators including convolution and multiplication (used in fully-connected layers) take as input not only the variables generated by the preceding operators but also [parameter](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/python_api.md#parameter) variables, which contain the trained weights to apply to the input data. These weights are obtained in the Fluid training process and are by default of float data type.
+
+When these operators are running in float16 mode, the float16 kernel requires those parameter variables to contain weights of Fluid float16 data type. Thus, we need a convenient way to convert the original float weights to float16 weights.
+
+In Fluid, we use tensor to hold the actual data for a variable on the c++ end. [Pybind](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/pybind/tensor_py.h) is used to bind c++ tensors of a certain data type with numpy arrays of the corresponding numpy data type on the Python end. Each common c++ built-in data type has a corresponding numpy data type of the same name. However, since there is no built-in float16 type in c++, we cannot directly bind the numpy float16 data type with the Fluid float16 class. Since both Fluid float16 and numpy float16 use uint16 as the internal data storage type, we use the c++ built-in type `uint16_t` and the corresponding numpy uint16 data type to bridge the gap via [Pybind](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/pybind/tensor_py.h).
+
+The following code demonstrates how to do the tensor conversion.
+```Python
+# var is the variable of float weights
+# tensor is a numpy array of data copied from the tensor data in var
+# fp16_var is the variable that will contain float16 weights converted from var
+tensor = numpy.array(var.get_tensor())
+fp16_tensor = fp16_var.get_tensor()
+
+# After the original tensor data is converted to numpy float16 data type,
+# view(numpy.uint16) is used so that the internal memory of the numpy array
+# will be reinterpreted as uint16 data, which is bound to the
+# Fluid float16 class via pybind with the help of the uint16_t built-in c++ type
+fp16_tensor.set(tensor.astype(numpy.float16).view(numpy.uint16), GPUPlace)
+```
+
+### Consistent API requirement
+Basic inference in float16 mode requires users to feed input and obtain output both of float16 data type. However, in this way, the inference APIs are not consistent between float16 mode and float mode, and users may find it confusing and difficult to use float16 inference since they need extra steps to provide float16 input data and convert float16 output data back to float. To have a consistent API for different inference modes, we need to transpile the program desc in some way so that we can run float16 inference by feeding and fetching variables of float data type.
+
+This problem can be solved by introducing a type-casting operator which takes an input variable of a certain data type, casts it to another specified data type, and puts the cast data into the output variable. Inserting cast operators where needed can make a program run internally in float16 mode.
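+
+As a minimal illustration (it assumes the `fluid.layers.cast` Python wrapper around the cast operator), appending such a cast to a program could look like:
+
+```Python
+import paddle.fluid as fluid
+
+# Appends a cast operator to the default main program; the new variable
+# holds the same values converted to float16.
+data = fluid.layers.data(name='x', shape=[3, 32, 32], dtype='float32')
+data_fp16 = fluid.layers.cast(x=data, dtype='float16')
+```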
+
+### float16 transpiler
+With all the above requirements in mind, we designed a float16 inference transpiler that can transpile a float32 mode inference program desc into a float16 mode one.
+
+Given a float inference program and the corresponding variables of float32 weights in the [scope](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/concepts/scope.md),
+this transpiler mainly does the following modifications:
+
+1. Insert cast operators at the beginning of the program so that the input float data will be converted to float16 data type before feeding to subsequent operators to invoke the float16 kernel.
+
+2. Insert cast operators at the end of the program so that the output float16 data will be converted back to float data type before users obtain the result.
+
+3. For each parameter variable of float weights, create in the scope a corresponding variable holding float16 weights converted from the float weights, and add this new float16 variable to the program.
+
+4. Update the operator information in the program so that each relevant operator uses the newly created float16 variable instead of its float counterpart.
+
+Below is an example of usage:
+```Python
+# Get the float inference program
+[float_inference_program, feed_target_names,
+ fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+
+# Prepare the float input data
+tensor_img = numpy.random.rand(1, 3, 32, 32).astype(numpy.float32)
+
+# Running inference_program in float mode
+float_results = exe.run(float_inference_program,
+ feed={feed_target_names[0]: tensor_img},
+ fetch_list=fetch_targets)
+
+# Use the float16 transpiler to speed up inference
+float16_inference_program = float_inference_program.clone()
+t = fluid.InferenceTranspiler()
+t.float16_transpile(float16_inference_program, GPUPlace)
+
+# Running float16_inference_program in float16 mode
+float16_results = exe.run(float16_inference_program,
+ feed={feed_target_names[0]: tensor_img},
+ fetch_list=fetch_targets)
+```
+
+As we can see from the example above, users can simply call the `float16_transpile` method provided by the inference transpiler class on an existing float inference program to run inference in float16 mode.
+
+### Speedup on GPU
+Currently, Fluid inference in float16 mode is only supported on Nvidia GPU devices. There is no motivation to support float16 inference on non-ARM CPUs because float16 is not natively supported there and float16 calculation will only be slower than its float counterpart.
+
+Nvidia started to support its native float16 data type (which has the same internal memory representation as the Fluid float16 class) on CUDA 7.5. Moreover, float16 speedups on common computationally intensive tasks, including GEMM (general matrix-matrix multiplication) and convolution, have been supported since cuBLAS 7.5 and cuDNN 5.0.
+
+Recently, the introduction of [tensor core](https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/) in Volta architecture GPUs and the support of tensor core calculation in CUDA 9.0 and cuDNN 7.0 make float16 truly superior to float in certain deep learning applications. Please refer to this [benchmark report](https://github.com/kexinzhao/Paddle_benchmark/blob/master/float16_benchmark.md) for more details.
diff --git a/doc/fluid/design/onnx/images/project_structure.png b/doc/fluid/design/onnx/images/project_structure.png
new file mode 100644
index 0000000000..ab1c2ff23c
Binary files /dev/null and b/doc/fluid/design/onnx/images/project_structure.png differ
diff --git a/doc/fluid/design/onnx/onnx_convertor.md b/doc/fluid/design/onnx/onnx_convertor.md
new file mode 100644
index 0000000000..bc1665d7c3
--- /dev/null
+++ b/doc/fluid/design/onnx/onnx_convertor.md
@@ -0,0 +1,131 @@
+# Background
+
+[ONNX (Open Neural Network Exchange)](https://github.com/onnx/onnx) bridges different deep learning frameworks by providing an open source graph format for models. Models trained in other frameworks can be converted into the ONNX format to execute inference by utilizing the built-in operators in ONNX - this is called a **frontend**. With the inverse conversion (called a **backend**), different frameworks can in principle share any model supported by ONNX. Most mainstream frameworks, e.g. Caffe2, PyTorch, and MXNet, have now joined the ONNX community, and there is momentum driving more and more vendors to begin supporting ONNX or even choosing ONNX as the only machine learning runtime in their devices.
+
+Therefore, it is necessary to enable the conversion between PaddlePaddle and ONNX. This design doc is aimed at implementing a convertor, mainly for converting between **Fluid** models and ONNX (it is very likely that we may support older v2 models in the future). A complete convertor should be bidirectional - with a frontend AND a backend, but considering the importance, we will start with the frontend, i.e. Fluid models to ONNX models.
+
+
+# How it works
+
+ONNX has a [working list of operators](https://github.com/onnx/onnx/blob/master/docs/Operators.md) which is versioned.
+
+When prioritizing implementation of a frontend over a backend, the choice of which Fluid -> ONNX operators to cover comes down to the choice of models to be supported (see section `Supported models`). Eventually, this will allow us to reach really wide coverage of all operators.
+
+Here are a few major considerations when it comes to converting models:
+
+- **Op-level conversion**: How to map the inputs, attributes, and outputs of each Paddle operator to those of the ONNX operator. In several cases, these require transformations. For each direction (frontend vs. backend), a different conversion mapping is needed.
+- **Parameters (weights) initialization**: Setting initial parameters on different nodes.
+- **Tensor data type mapping** (Note: Some ONNX data types are not supported in Fluid)
+- **Network representation adaption**: Fluid `ProgramDesc` includes nested blocks. Since ONNX is free of nesting, the `ProgramDesc` ops need to be traversed to only include ops from the global scope in the root block. The variables used as inputs and outputs should also be in this scope (a short traversal sketch follows this list).
+- **Model validation**: There are two kinds of validations that are necessary:
+  1. We need to ensure that the inference outputs of the ops run inside a model are the same as those when running the ONNX converted ops through an alternative ONNX backend.
+ 2. Checking to see if the generated nodes on the graph are validated by the internal ONNX checkers.
+- **Versioning**: ONNX versions its operator listing across releases. In fact, it has versioning on 3 different levels: ops, graphs, and ONNX models. This requires that we are conscious about versioning the convertor and updating tests and op convertor logic for each release. It also implies that we release pre-trained ONNX models upon each version release.
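+
+As referenced in the network representation adaption item above, here is a short sketch (it assumes only the public Fluid Python `Program`/`Block` API) of traversing just the root block of a program:
+
+```Python
+import paddle.fluid as fluid
+
+def collect_root_block(program):
+    """Collect op types and variable names from the global (root) block only."""
+    block = program.global_block()
+    op_types = [op.type for op in block.ops]
+    var_names = list(block.vars.keys())
+    return op_types, var_names
+
+op_types, var_names = collect_root_block(fluid.default_main_program())
+```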
+
+One thing that makes this conversion more feasible in Fluid's case is the use of a static IR - the `ProgramDesc` - as opposed to a dynamic graph, as created in the cases of frameworks like PyTorch.
+
+
+# Project structure
+
+![Project structure](./images/project_structure.png)
+
+The project contains four important parts:
+
+* **fluid**: The directory that contains wrappers for fluid related APIs. Fluid has provided some low-level APIs to parse or generate the inference model. However, directly using these low-level APIs makes the code tediously long. This module wraps low-level APIs to provide simplified interfaces.
+
+* **onnx**: This is a Python package provided by ONNX containing helpers for creating nodes, graphs, and eventually binary protobuf models with initializer parameters.
+
+* **onnx_fluid**: Contains the two-way mapping (Fluid -> ONNX ops and ONNX -> Fluid ops). Called from `convert.py`, the program uses this mapping along with modifier functions to construct ONNX nodes with the help of ONNX's `make_node` helper (see the short sketch after this list). It also contains the mapping between datatypes and tensor deprecation / amplification logic.
+
+* **convert.py**: The interface exposed to users. This will traverse the global program blocks/variables and construct the write-able model.
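+
+As referenced in the `onnx_fluid` item above, here is a minimal sketch (it assumes only the public `onnx.helper` and `onnx.checker` APIs) of constructing a node, a graph, and a model, and validating it with the internal ONNX checker:
+
+```Python
+import onnx
+from onnx import helper, TensorProto
+
+# Build a single Relu node and wrap it into a graph and a model proto.
+node = helper.make_node('Relu', inputs=['x'], outputs=['y'])
+x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 32, 32])
+y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 32, 32])
+graph = helper.make_graph([node], 'demo_graph', inputs=[x], outputs=[y])
+model = helper.make_model(graph)
+onnx.checker.check_model(model)
+```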
+
+
+# Usage
+The converter should be designed to be very easy to use. Bidirectional conversion between a Fluid inference model and an ONNX binary model will be supported. Model validation will also be provided to verify the correctness of the converted model.
+
+* Convert Fluid inference model to ONNX binary model
+
+ ```
+  python convert.py --fluid_model <fluid inference model> --onnx_model <onnx model> validate True
+ ```
+
+* Validate the converted model
+
+ ```
+  python validate.py --fluid_model <fluid inference model> --onnx_model <onnx model>
+ ```
+
+The conversion and model validation will be completed consecutively, finally outputting a readable model structure description. For the reverse conversion, users only need to exchange the input and output.
+
+
+# Challenges and mitigation
+
+## Cycles
+
+Cycles are unsupported in ONNX. In Paddle, the `while` op is the most prominent example of a cycle.
+
+*Resolution*: We won't support models with `while`s which can't be substituted until ONNX adds support for such ops.
+
+## Sequences
+
+Sequence processing operators like `sequence_expand`, `sequence_reshape`, `sequence_concat`, and `sequence_pool` are also not supported by ONNX, because ONNX does not support non-padded datatypes like LoDTensors.
+
+*Resolution*: Since the runtimes using our ONNX exported graphs won't be using LoDTensors in the first place, such sequence operators should be mapped to ONNX ops that will do the necessary transposing ops with the knowledge of the padding and shape of the Tensors.
+
+## Ops that can't easily be mapped
+
+There are ops that just aren't possible to map today:
+
+**Control flow operators**
+
+Paddle supports control flow ops like `If/Else` and `Switch` (if we ignore the CSP operations like `select` for now). ONNX has `If` support in the experimental phase.
+
+*Resolution*: Map Paddle's `If/Else` to ONNX's `If`, but ignore other control flow operators until ONNX brings support for them.
+
+
+**Non-existent in Fluid**
+
+There are several ONNX operators that are not available in Fluid today, e.g. `InstanceNormalization`, `RandomUniform`, `Unsqueeze`, etc.
+
+*Resolution*: For the initial phase, we can choose to not support ops that our models don't care for and are subsequently not available in Fluid. However, for ops that we think might be necessary for Fluid users also, we must implement them on our side and support the ONNX conversion to them. This list is TBD.
+
+
+**Concurrency**
+
+ONNX does not have any considerations for concurrency right now.
+
+*Resolution*: There are two ways to approach this:
+
+a. We choose to not support concurrent models.
+b. We only support `go_op`s (basically threads) shallowly. This could mean that we enqueue `go_op` ops prior to gradient calculations OR even prior to the entire graph, and that's it - since `go_op`s do not have support for backprop anyways. One of the core target use cases of `go_op`: batch reading - can be handled through this approach.
+
+
+**Overloaded in Fluid**
+
+There are ops in ONNX whose job can't be accomplished by a single corresponding Paddle operator (e.g. ), but only by a collection of operators.
+
+*Resolution*: Chain multiple Paddle operators.
+
+
+## Lack of LoDTensors
+
+As stated above, ONNX only supports simple Tensor values.
+
+*Resolution*: Deprecate to plain old numpy-able tensors.
+
+
+## Reconstruction from deprecated ONNX ops
+
+For higher-level Fluid ops, such as a few offered by the `nn` layer that do not have direct corresponding mappings but can be converted to ONNX by chaining a series of ops without cycles, it would be useful to map them back to the higher-level Fluid ops once converted back from the deprecated ONNX graphs.
+
+*Resolution*: Keep track of the graphs that result from the deprecation during Paddle -> ONNX conversion. When converting back from ONNX, if we encounter identical subgraphs by doing a forward search, we can replace them with the matching higher-level Fluid op.
+
+
+# Supported models
+
+As mentioned above, potential risks may come from the conversion of sequence-related models, including LoDTensor and the ```if/else``` and ```while``` operators. So a good choice is to focus on some important feedforward models first, then implement some simple recurrent models.
+
+- Feedforward models: common models selected in PaddleBook, e.g. VGG, ResNet and some other models proposed by application teams.
+- Recurrent models: language model, stacked LSTMs etc.
diff --git a/doc/v2/api/data/data_reader.rst b/doc/v2/api/data/data_reader.rst
index d7c896a627..1a35d0bbc8 100644
--- a/doc/v2/api/data/data_reader.rst
+++ b/doc/v2/api/data/data_reader.rst
@@ -56,11 +56,11 @@ DataFeeder
Reader
======
-.. automodule:: paddle.v2.reader
+.. automodule:: paddle.reader
:members:
:noindex:
-.. automodule:: paddle.v2.reader.creator
+.. automodule:: paddle.reader.creator
:members:
:noindex:
diff --git a/doc/v2/api/data/dataset.rst b/doc/v2/api/data/dataset.rst
index 02e41564b1..e7c8be4452 100644
--- a/doc/v2/api/data/dataset.rst
+++ b/doc/v2/api/data/dataset.rst
@@ -1,82 +1,82 @@
Dataset
=======
-.. automodule:: paddle.v2.dataset
+.. automodule:: paddle.dataset
:members:
:noindex:
mnist
+++++
-.. automodule:: paddle.v2.dataset.mnist
+.. automodule:: paddle.dataset.mnist
:members:
:noindex:
cifar
+++++
-.. automodule:: paddle.v2.dataset.cifar
+.. automodule:: paddle.dataset.cifar
:members:
:noindex:
conll05
+++++++
-.. automodule:: paddle.v2.dataset.conll05
+.. automodule:: paddle.dataset.conll05
:members: get_dict,get_embedding,test
:noindex:
imdb
++++
-.. automodule:: paddle.v2.dataset.imdb
+.. automodule:: paddle.dataset.imdb
:members:
:noindex:
imikolov
++++++++
-.. automodule:: paddle.v2.dataset.imikolov
+.. automodule:: paddle.dataset.imikolov
:members:
:noindex:
movielens
+++++++++
-.. automodule:: paddle.v2.dataset.movielens
+.. automodule:: paddle.dataset.movielens
:members:
:noindex:
-.. autoclass:: paddle.v2.dataset.movielens.MovieInfo
+.. autoclass:: paddle.dataset.movielens.MovieInfo
:noindex:
-
-.. autoclass:: paddle.v2.dataset.movielens.UserInfo
+
+.. autoclass:: paddle.dataset.movielens.UserInfo
:noindex:
sentiment
+++++++++
-.. automodule:: paddle.v2.dataset.sentiment
+.. automodule:: paddle.dataset.sentiment
:members:
:noindex:
uci_housing
+++++++++++
-.. automodule:: paddle.v2.dataset.uci_housing
+.. automodule:: paddle.dataset.uci_housing
:members:
:noindex:
wmt14
+++++
-.. automodule:: paddle.v2.dataset.wmt14
+.. automodule:: paddle.dataset.wmt14
:members:
:noindex:
wmt16
+++++
-.. automodule:: paddle.v2.dataset.wmt16
+.. automodule:: paddle.dataset.wmt16
:members:
:noindex:
diff --git a/doc/v2/howto/cluster/multi_cluster/index_en.rst b/doc/v2/howto/cluster/multi_cluster/index_en.rst
index dac7aaef08..b69bd5b2db 100644
--- a/doc/v2/howto/cluster/multi_cluster/index_en.rst
+++ b/doc/v2/howto/cluster/multi_cluster/index_en.rst
@@ -1,19 +1,35 @@
Use different clusters
======================
-PaddlePaddle supports running jobs on several platforms including:
-- `Kubernetes `_ open-source system for automating deployment, scaling, and management of containerized applications from Google.
-- `OpenMPI `_ Mature high performance parallel computing framework.
-- `Fabric `_ A cluster management tool. Write scripts to submit jobs or manage the cluster.
+Users' cluster environments vary. To make deployment easier, we provide several ways to submit cluster training jobs, which are introduced as follows:
-We'll introduce cluster job management on these platforms. The examples can be found under `cluster_train_v2 `_ .
+`Kubernetes `_ is Google's open-source container cluster scheduling framework, providing a complete solution for large-scale production cluster environments. The following guides describe PaddlePaddle's support for Kubernetes:
-These cluster platforms provide API or environment variables for training processes, when the job is dispatched to different nodes. Like node ID, IP or total number of nodes etc.
+.. toctree::
+ :maxdepth: 1
+
+ k8s_cn.md
+ k8s_distributed_cn.md
+
+`OpenMPI `_ is a mature high-performance parallel computing framework, widely used in the field of HPC. The following guide describes how to use OpenMPI to run a PaddlePaddle cluster training job:
.. toctree::
:maxdepth: 1
- fabric_en.md
- openmpi_en.md
- k8s_en.md
- k8s_aws_en.md
+ openmpi_cn.md
+
+`Fabric `_ is a convenient tool for program deployment and management. We provide a way to deploy and manage clusters with Fabric. To learn more about it, please read the following guide:
+
+.. toctree::
+ :maxdepth: 1
+
+ fabric_cn.md
+
+We also support deploying PaddlePaddle on AWS. Learn more in:
+
+.. toctree::
+ :maxdepth: 1
+
+ k8s_aws_cn.md
+
+The examples can be found under `cluster_train_v2 `_ .
\ No newline at end of file
diff --git a/paddle/fluid/framework/blocking_queue.h b/paddle/fluid/framework/blocking_queue.h
new file mode 100644
index 0000000000..a19558c0ae
--- /dev/null
+++ b/paddle/fluid/framework/blocking_queue.h
@@ -0,0 +1,74 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include // NOLINT
+#include
+#include // NOLINT
+#include
+
+namespace paddle {
+namespace framework {
+
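+// A simple thread-safe FIFO queue guarded by a mutex; producers call
+// Push/Extend and notify consumers blocked in Pop/PopAll.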
+template
+class BlockingQueue {
+ public:
+ void Push(const T &item) {
+ {
+ std::lock_guard g(mutex_);
+ q_.emplace_back(item);
+ }
+ cv_.notify_one();
+ }
+
+ template
+ void Extend(const U &items) {
+ {
+ std::lock_guard g(mutex_);
+ for (auto &item : items) {
+ q_.emplace_back(item);
+ }
+ }
+ cv_.notify_all();
+ }
+
+ std::deque PopAll(size_t ms, bool *timeout) {
+ auto time =
+ std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
+ std::unique_lock lock(mutex_);
+ *timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); });
+ std::deque ret;
+ if (!*timeout) {
+ std::swap(ret, q_);
+ }
+ return ret;
+ }
+
+ T Pop() {
+ std::unique_lock lock(mutex_);
+ cv_.wait(lock, [=] { return !q_.empty(); });
+ T rc(std::move(q_.front()));
+ q_.pop_front();
+ return rc;
+ }
+
+ private:
+ std::mutex mutex_;
+ std::condition_variable cv_;
+ std::deque q_;
+};
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc
index bfad9ac1e9..9c277a27da 100644
--- a/paddle/fluid/framework/data_transform.cc
+++ b/paddle/fluid/framework/data_transform.cc
@@ -63,16 +63,16 @@ void DataTransform(const OpKernelType& expected_kernel_type,
}
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
- Variable& out_var) {
+ Variable* out_var) {
if (in_var.IsType()) {
auto& in_lod_tensor = in_var.Get();
- auto* tran_lod_tensor = out_var.GetMutable();
+ auto* tran_lod_tensor = out_var->GetMutable();
tran_lod_tensor->set_lod(in_lod_tensor.lod());
tran_lod_tensor->set_layout(in_lod_tensor.layout());
tran_lod_tensor->ShareDataWith(tensor);
} else if (in_var.IsType()) {
auto& in_selected_rows = in_var.Get();
- auto* trans_selected_rows = out_var.GetMutable();
+ auto* trans_selected_rows = out_var->GetMutable();
trans_selected_rows->set_height(in_selected_rows.height());
trans_selected_rows->set_rows(in_selected_rows.rows());
trans_selected_rows->mutable_value()->ShareDataWith(tensor);
diff --git a/paddle/fluid/framework/data_transform.h b/paddle/fluid/framework/data_transform.h
index 9ec67e6f3d..dee5d8c7c1 100644
--- a/paddle/fluid/framework/data_transform.h
+++ b/paddle/fluid/framework/data_transform.h
@@ -35,7 +35,7 @@ void DataTransform(const OpKernelType& expected_kernel_type,
const Tensor& input_tensor, Tensor* out);
void CopyVariableWithTensor(const Variable& in_var, const Tensor& tensor,
- Variable& out_var);
+ Variable* out_var);
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/details/broadcast_op_handle_test.cc b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
index 3f2dcde3e9..8f1b6d1615 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle_test.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle_test.cc
@@ -139,7 +139,7 @@ struct TestBroadcastOpHandle {
PADDLE_ENFORCE_EQ(out_tensor.lod(), lod, "lod is not equal.");
f::Tensor result_tensor;
- f::TensorCopy(out_tensor, cpu_place, *(ctxs_[j]), &result_tensor);
+ f::TensorCopySync(out_tensor, cpu_place, &result_tensor);
float* ct = result_tensor.mutable_data(cpu_place);
for (int64_t i = 0; i < f::product(kDims); ++i) {
@@ -185,7 +185,7 @@ struct TestBroadcastOpHandle {
}
f::Tensor result_tensor;
- f::TensorCopy(rt, cpu_place, *(ctxs_[j]), &result_tensor);
+ f::TensorCopySync(rt, cpu_place, &result_tensor);
float* ct = result_tensor.data();
for (int64_t i = 0; i < f::product(kDims); ++i) {
diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc
index 423449abff..1e8ca20b51 100644
--- a/paddle/fluid/framework/details/fetch_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_op_handle.cc
@@ -66,8 +66,7 @@ void FetchOpHandle::RunImpl() {
auto &t = var->Get();
if (platform::is_gpu_place(t.place())) {
#ifdef PADDLE_WITH_CUDA
- TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i], true);
- dev_ctxes_.at(t.place())->Wait();
+ TensorCopySync(t, cpu, &tensors_[i]);
#endif
} else {
tensors_[i].ShareDataWith(t);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 3413467b14..c2eb1c31b4 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -58,23 +58,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
const OpDesc &op,
- const platform::Place &p,
- const size_t &i) const {
+ size_t place_id) const {
+ auto p = places_[place_id];
auto *op_handle = result->ops_.back().get();
op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p));
- auto var_names = op.InputArgumentNames();
-
- for (auto &each_var_name : var_names) {
- VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
+ for (auto &each_var_name : op.InputArgumentNames()) {
+ VarHandle *var =
+ CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
op_handle->AddInput(var);
}
- var_names = op.OutputArgumentNames();
-
- for (auto &each_var_name : var_names) {
- CreateOpOutput(result, op_handle, each_var_name, p, i);
+ for (auto &each_var_name : op.OutputArgumentNames()) {
+ CreateOpOutput(result, op_handle, each_var_name, p, place_id);
}
}
@@ -84,17 +81,18 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
return false;
}
- auto checker = [&](const std::vector opvars,
- const std::vector sendvars) -> bool {
- bool is_dist_train_op = false;
+ /**
+   * Check whether any of opvars contains `.block` and also appears in sendvars.
+ */
+ auto checker = [](const std::vector &opvars,
+ const std::vector &sendvars) -> bool {
for (auto &var : opvars) {
if (var.find(".block") != std::string::npos &&
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
- is_dist_train_op = true;
- break;
+ return true;
}
}
- return is_dist_train_op;
+ return false;
};
if (op.Type() == "split") {
@@ -117,13 +115,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
places_.size());
// Find "send" op first for split is in front of send.
- OpDesc *send_op = nullptr;
- for (auto *op : program.Block(0).AllOps()) {
- if (op->Type() == "send") {
- send_op = op;
- break;
- }
- }
+ OpDesc *send_op = GetSendOpDesc(program);
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
@@ -134,6 +126,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
} else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) {
+ // user can customize loss@grad if skip_scale_loss_
if (!skip_scale_loss_) {
CreateScaleLossGradOp(&result);
}
@@ -142,10 +135,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding) {
// Currently, we assume that once gradient is generated, it can be
- // broadcast, and each gradient is only broadcast once. But there are no
- // other cases, for example, we need to adjust the gradient according to
- // the input when we get the gradient, which is not considered at
- // present.
+ // broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
InsertNCCLAllReduceOp(&result, og);
@@ -175,6 +165,16 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
return std::unique_ptr(graph);
}
+OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
+ const ProgramDesc &program) const {
+ for (auto *op : program.Block(0).AllOps()) {
+ if (op->Type() == "send") {
+ return op;
+ }
+ }
+ return nullptr;
+}
+
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA
@@ -243,7 +243,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx];
result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
- CreateOpHandleIOs(result, op, p, scope_idx);
+ CreateOpHandleIOs(result, op, scope_idx);
}
}
@@ -255,7 +255,7 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
result->ops_.emplace_back(new SendOpHandle(op, s, p));
// Create inputs for output on original place and no ssa output
// is created for send op.
- CreateOpHandleIOs(result, op, p, 0);
+ CreateOpHandleIOs(result, op, 0);
}
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index dc3da70eda..fa4d31bdc4 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
- const platform::Place &p, const size_t &i) const;
+ size_t place_id) const;
private:
std::string loss_var_name_;
@@ -65,6 +65,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
+ /**
+   * Is this operator the end-point operator before/after the send operator?
+ */
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
@@ -77,6 +80,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unordered_set *og_has_been_broadcast) const;
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
+
+ /**
+   * Get the send op in the global block of the program.
+   * Returns nullptr if not found.
+ */
+ OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
};
} // namespace details
} // namespace framework
diff --git a/paddle/fluid/framework/details/reduce_op_handle_test.cc b/paddle/fluid/framework/details/reduce_op_handle_test.cc
index c17aabee53..ffdd7c14eb 100644
--- a/paddle/fluid/framework/details/reduce_op_handle_test.cc
+++ b/paddle/fluid/framework/details/reduce_op_handle_test.cc
@@ -194,7 +194,7 @@ struct TestReduceOpHandle {
}
f::Tensor result_tensor;
- f::TensorCopy(rt, cpu_place, *(ctxs_[output_scope_idx]), &result_tensor);
+ f::TensorCopySync(rt, cpu_place, &result_tensor);
float *ct = result_tensor.data();
for (int64_t j = 0; j < f::product(result_tensor.dims()); ++j) {
@@ -239,7 +239,7 @@ struct TestReduceOpHandle {
auto &rt = out_var->Get();
f::Tensor result_tensor;
- f::TensorCopy(rt, cpu_place, *(ctxs_[output_scope_idx]), &result_tensor);
+ f::TensorCopySync(rt, cpu_place, &result_tensor);
float *ct = result_tensor.data();
for (int64_t j = 0; j < f::product(result_tensor.dims()); ++j) {
diff --git a/paddle/fluid/framework/details/ssa_graph.h b/paddle/fluid/framework/details/ssa_graph.h
index 72684e7f97..e996a00c16 100644
--- a/paddle/fluid/framework/details/ssa_graph.h
+++ b/paddle/fluid/framework/details/ssa_graph.h
@@ -25,12 +25,22 @@ namespace paddle {
namespace framework {
namespace details {
+// A SSA graph used by parallel executor.
struct SSAGraph {
+  // All variables on each device.
+  // The outside vector is the device vector. Each element of this vector is a
+  // map from variable name to variables. Variables that have the same name
+  // will have different versions. The offset in the
+  // `std::vector>` is the version of the variable.
std::vector<
std::unordered_map>>>
vars_;
+
// aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set> dep_vars_;
+
+  // All operators. NOTE that even though we use a vector here, the operators
+  // are unordered.
std::vector> ops_;
};
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h
index be1f0460e4..64e5d93081 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.h
+++ b/paddle/fluid/framework/details/ssa_graph_builder.h
@@ -48,6 +48,8 @@ class SSAGraphBuilder {
const platform::Place &place,
size_t place_offset);
+ // Add an output variable (each_var_name, place, place_offset) to op_handle,
+  // which belongs to the graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name,
const platform::Place &place, size_t place_offset);
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index d70bbd4ef0..d089b79d91 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -22,6 +22,7 @@
#include
#include "ThreadPool.h" // ThreadPool in thrird party
+#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
namespace paddle {
@@ -30,46 +31,6 @@ class Scope;
namespace details {
-template
-class BlockingQueue {
- public:
- void Push(const T &item) {
- {
- std::lock_guard g(mutex_);
- q_.emplace_back(item);
- }
- cv_.notify_one();
- }
-
- template
- void Extend(const U &items) {
- {
- std::lock_guard g(mutex_);
- for (auto &item : items) {
- q_.emplace_back(item);
- }
- }
- cv_.notify_all();
- }
-
- std::deque PopAll(size_t ms, bool *timeout) {
- auto time =
- std::chrono::system_clock::now() + std::chrono::milliseconds(ms);
- std::unique_lock lock(mutex_);
- *timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); });
- std::deque ret;
- if (!*timeout) {
- std::swap(ret, q_);
- }
- return ret;
- }
-
- private:
- std::mutex mutex_;
- std::condition_variable cv_;
- std::deque q_;
-};
-
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 513e720fd0..766bf0ab0c 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -226,15 +226,15 @@ static bool has_fetch_operators(
}
void Executor::Run(const ProgramDesc& program, Scope* scope,
- std::map& feed_targets,
- std::map& fetch_targets,
+ std::map* feed_targets,
+ std::map* fetch_targets,
bool create_vars, const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
bool has_feed_ops =
- has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
+ has_feed_operators(program.Block(0), *feed_targets, feed_holder_name);
bool has_fetch_ops =
- has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);
+ has_fetch_operators(program.Block(0), *fetch_targets, fetch_holder_name);
ProgramDesc* copy_program = const_cast(&program);
if (!has_feed_ops || !has_fetch_ops) {
@@ -250,7 +250,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
feed_holder->SetPersistable(true);
int i = 0;
- for (auto& feed_target : feed_targets) {
+ for (auto& feed_target : (*feed_targets)) {
std::string var_name = feed_target.first;
VLOG(3) << "feed target's name: " << var_name;
@@ -273,7 +273,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
fetch_holder->SetPersistable(true);
int i = 0;
- for (auto& fetch_target : fetch_targets) {
+ for (auto& fetch_target : (*fetch_targets)) {
std::string var_name = fetch_target.first;
VLOG(3) << "fetch target's name: " << var_name;
@@ -361,16 +361,16 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
void Executor::RunPreparedContext(
ExecutorPrepareContext* ctx, Scope* scope,
- std::map& feed_targets,
- std::map& fetch_targets, bool create_vars,
+ std::map* feed_targets,
+ std::map* fetch_targets, bool create_vars,
const std::string& feed_holder_name, const std::string& fetch_holder_name) {
auto& global_block = ctx->prog_.Block(ctx->block_id_);
PADDLE_ENFORCE(
- has_feed_operators(global_block, feed_targets, feed_holder_name),
+ has_feed_operators(global_block, *feed_targets, feed_holder_name),
"Program in ExecutorPrepareContext should has feed_ops.");
PADDLE_ENFORCE(
- has_fetch_operators(global_block, fetch_targets, fetch_holder_name),
+ has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
"Program in the prepared context should has fetch_ops.");
// map the data of feed_targets to feed_holder
@@ -378,8 +378,8 @@ void Executor::RunPreparedContext(
if (op->Type() == kFeedOpType) {
std::string feed_target_name = op->Output("Out")[0];
int idx = boost::get(op->GetAttr("col"));
- SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
- idx);
+ SetFeedVariable(scope, *(*feed_targets)[feed_target_name],
+ feed_holder_name, idx);
}
}
@@ -390,7 +390,7 @@ void Executor::RunPreparedContext(
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
int idx = boost::get(op->GetAttr("col"));
- *fetch_targets[fetch_target_name] =
+ *(*fetch_targets)[fetch_target_name] =
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 43defdacf2..4a3d637e2d 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -55,8 +55,8 @@ class Executor {
bool create_local_scope = true, bool create_vars = true);
void Run(const ProgramDesc& program, Scope* scope,
- std::map& feed_targets,
- std::map& fetch_targets,
+ std::map* feed_targets,
+ std::map* fetch_targets,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
@@ -74,8 +74,8 @@ class Executor {
bool create_vars = true);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
- std::map& feed_targets,
- std::map& fetch_targets,
+ std::map* feed_targets,
+ std::map* fetch_targets,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
diff --git a/paddle/fluid/framework/init.cc b/paddle/fluid/framework/init.cc
index b30f276b4b..85beae775b 100644
--- a/paddle/fluid/framework/init.cc
+++ b/paddle/fluid/framework/init.cc
@@ -15,7 +15,6 @@ limitations under the License. */
#include
#include
#include
-#include
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/operator.h"
@@ -31,6 +30,7 @@ std::once_flag p2p_init_flag;
void InitGflags(std::vector argv) {
std::call_once(gflags_init_flag, [&]() {
+ argv.insert(argv.begin(), "dummy");
int argc = argv.size();
char **arr = new char *[argv.size()];
std::string line;
@@ -44,20 +44,23 @@ void InitGflags(std::vector argv) {
});
}
-void InitP2P(int count) {
+void InitP2P(std::vector devices) {
#ifdef PADDLE_WITH_CUDA
std::call_once(p2p_init_flag, [&]() {
+ int count = devices.size();
for (int i = 0; i < count; ++i) {
for (int j = 0; j < count; ++j) {
- if (i == j) continue;
+ if (devices[i] == devices[j]) continue;
int can_acess = -1;
- PADDLE_ENFORCE(cudaDeviceCanAccessPeer(&can_acess, i, j),
- "Failed to test P2P access.");
+ PADDLE_ENFORCE(
+ cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]),
+ "Failed to test P2P access.");
if (can_acess != 1) {
- LOG(WARNING) << "Cannot enable P2P access from " << i << " to " << j;
+ LOG(WARNING) << "Cannot enable P2P access from " << devices[i]
+ << " to " << devices[j];
} else {
- cudaSetDevice(i);
- cudaDeviceEnablePeerAccess(j, 0);
+ cudaSetDevice(devices[i]);
+ cudaDeviceEnablePeerAccess(devices[j], 0);
}
}
}
@@ -67,11 +70,26 @@ void InitP2P(int count) {
void InitDevices(bool init_p2p) {
/*Init all available devices by default */
+ std::vector devices;
+#ifdef PADDLE_WITH_CUDA
+ try {
+ int count = platform::GetCUDADeviceCount();
+ for (int i = 0; i < count; ++i) {
+ devices.push_back(i);
+ }
+ } catch (const std::exception &exp) {
+ LOG(WARNING) << "Compiled with WITH_GPU, but no GPU found in runtime.";
+ }
+#else
+ LOG(WARNING)
+ << "'CUDA' is not supported, Please re-compile with WITH_GPU option";
+#endif
+ InitDevices(init_p2p, devices);
+}
+void InitDevices(bool init_p2p, const std::vector devices) {
std::vector places;
- places.emplace_back(platform::CPUPlace());
int count = 0;
-
#ifdef PADDLE_WITH_CUDA
try {
count = platform::GetCUDADeviceCount();
@@ -83,12 +101,17 @@ void InitDevices(bool init_p2p) {
<< "'CUDA' is not supported, Please re-compile with WITH_GPU option";
#endif
- for (int i = 0; i < count; ++i) {
- places.emplace_back(platform::CUDAPlace(i));
+ for (size_t i = 0; i < devices.size(); ++i) {
+ if (devices[i] >= count || devices[i] < 0) {
+ LOG(WARNING) << "Invalid devices id.";
+ continue;
+ }
+ places.emplace_back(platform::CUDAPlace(devices[i]));
}
if (init_p2p) {
- InitP2P(count);
+ InitP2P(devices);
}
+ places.emplace_back(platform::CPUPlace());
platform::DeviceContextPool::Init(places);
}
diff --git a/paddle/fluid/framework/init.h b/paddle/fluid/framework/init.h
index 1155ca3604..0e30594672 100644
--- a/paddle/fluid/framework/init.h
+++ b/paddle/fluid/framework/init.h
@@ -28,5 +28,7 @@ void InitGLOG(const std::string &prog_name);
void InitDevices(bool init_p2p);
+void InitDevices(bool init_p2p, const std::vector devices);
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc
index 46c834b38b..076c457130 100644
--- a/paddle/fluid/framework/op_desc.cc
+++ b/paddle/fluid/framework/op_desc.cc
@@ -205,8 +205,8 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
need_update_ = true;
}
-void OpDesc::SetBlockAttr(const std::string &name, BlockDesc &block) {
- this->attrs_[name] = █
+void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
+ this->attrs_[name] = block;
need_update_ = true;
}
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index cd6777e60a..3ee36a47c1 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include
#include
#include "paddle/fluid/framework/attribute.h"
@@ -73,7 +74,7 @@ class OpDesc {
void SetAttr(const std::string &name, const Attribute &v);
- void SetBlockAttr(const std::string &name, BlockDesc &block);
+ void SetBlockAttr(const std::string &name, BlockDesc *block);
Attribute GetAttr(const std::string &name) const;
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index f97bd08274..32576423a6 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -171,17 +171,6 @@ std::string OperatorBase::DebugStringEx(const Scope* scope) const {
return ss.str();
}
-void OperatorBase::Rename(const std::string& old_name,
- const std::string& new_name) {
- for (auto& input : inputs_) {
- std::replace(input.second.begin(), input.second.end(), old_name, new_name);
- }
- for (auto& output : outputs_) {
- std::replace(output.second.begin(), output.second.end(), old_name,
- new_name);
- }
-}
-
OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& inputs,
const VariableNameMap& outputs,
@@ -327,7 +316,6 @@ bool OpSupportGPU(const std::string& op_type) {
auto it = all_kernels.find(op_type);
if (it == all_kernels.end()) {
// All control operator must support GPU
-
return true;
}
for (auto& kern_pair : it->second) {
@@ -554,7 +542,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
std::shared_ptr out(new Tensor);
DataTransform(expected_kernel_key, kernel_type_for_var, *tensor_in,
out.get());
- CopyVariableWithTensor(*var, *(out.get()), *trans_var);
+ CopyVariableWithTensor(*var, *(out.get()), trans_var);
}
}
}
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index b7a7c69b4c..826cc57b72 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -79,31 +79,28 @@ class OperatorBase {
virtual ~OperatorBase() {}
- template
- inline const T& Attr(const std::string& name) const {
- PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
- name);
- return boost::get(attrs_.at(name));
- }
-
- /// if scope is not null, also show dimensions of arguments
- virtual std::string DebugStringEx(const Scope* scope) const;
-
- std::string DebugString() const { return DebugStringEx(nullptr); }
-
- /// Net will call this interface function to Run an op.
+ /// Executor will call this interface function to Run an op.
// The implementation should be written at RunImpl
void Run(const Scope& scope, const platform::Place& place);
// FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
virtual void Stop() {}
- virtual bool IsNetOp() const { return false; }
+ /// if scope is not null, also show dimensions of arguments
+ virtual std::string DebugStringEx(const Scope* scope) const;
+ std::string DebugString() const { return DebugStringEx(nullptr); }
virtual bool SupportGPU() const { return false; }
- /// rename inputs outputs name
- void Rename(const std::string& old_name, const std::string& new_name);
+ const std::string& Type() const { return type_; }
+
+ template
+ inline const T& Attr(const std::string& name) const {
+ PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
+ name);
+ return boost::get(attrs_.at(name));
+ }
+ const AttributeMap& Attrs() const { return attrs_; }
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
@@ -112,7 +109,7 @@ class OperatorBase {
std::string Input(const std::string& name) const;
//! Get a input which has multiple variables.
const std::vector& Inputs(const std::string& name) const;
-
+ //! Get all inputs variable names
std::vector InputVars() const;
//! Get a output with argument's name described in `op_proto`
@@ -120,13 +117,9 @@ class OperatorBase {
//! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy.
const std::vector& Outputs(const std::string& name) const;
-
+ //! Get all outputs variable names
virtual std::vector OutputVars(bool has_intermediate) const;
- const std::string& Type() const { return type_; }
- void SetType(const std::string& type) { type_ = type; }
- const AttributeMap& Attrs() const { return attrs_; }
-
// Return a new operator instance, which is as same as this.
// Use unique_ptr to prevent caller forget to delete this pointer.
virtual std::unique_ptr Clone() const = 0;
@@ -278,20 +271,6 @@ class ExecutionContext {
return res;
}
- void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
- size_t j = 0) const {
- PADDLE_ENFORCE_LT(i, InputSize(in));
- PADDLE_ENFORCE_LT(j, OutputSize(out));
- auto* in_var = MultiInputVar(in)[i];
- auto* out_var = MultiOutputVar(out)[j];
- if (!in_var->IsType()) return;
- PADDLE_ENFORCE(out_var->IsType(),
- "The %d-th output of Output(%s) must be LoDTensor.", j, out);
- auto in_tensor = in_var->Get();
- auto* out_tensor = out_var->GetMutable();
- out_tensor->set_lod(in_tensor.lod());
- }
-
platform::Place GetPlace() const { return device_context_.GetPlace(); }
template
diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc
index 16694bcf76..64fb028f83 100644
--- a/paddle/fluid/framework/program_desc.cc
+++ b/paddle/fluid/framework/program_desc.cc
@@ -56,7 +56,7 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
- op->SetBlockAttr(attr.name(), *this->MutableBlock(blk_idx));
+ op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
}
}
}
@@ -73,7 +73,7 @@ ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
for (const auto &attr : op->Proto()->attrs()) {
if (attr.type() == proto::AttrType::BLOCK) {
size_t blk_idx = attr.block_idx();
- op->SetBlockAttr(attr.name(), *this->MutableBlock(blk_idx));
+ op->SetBlockAttr(attr.name(), this->MutableBlock(blk_idx));
}
}
}
diff --git a/paddle/fluid/framework/prune.cc b/paddle/fluid/framework/prune.cc
index 107c5bf8ec..57c1b822d8 100644
--- a/paddle/fluid/framework/prune.cc
+++ b/paddle/fluid/framework/prune.cc
@@ -14,19 +14,19 @@ limitations under the License. */
#include "paddle/fluid/framework/prune.h"
+#include
+
#include
#include
#include
#include
#include
-#include
-
namespace paddle {
namespace framework {
-const std::string kFeedOpType = "feed";
-const std::string kFetchOpType = "fetch";
+const char kFeedOpType[] = "feed";
+const char kFetchOpType[] = "fetch";
bool HasDependentVar(const proto::OpDesc& op_desc,
const std::set<std::string>& dependent_vars) {
@@ -68,7 +68,7 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
// the child block to help pruning
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
int block_id, int parent_block_id,
- std::set<std::string>& dependent_vars) {
+ std::set<std::string>* dependent_vars) {
auto& block = input.blocks(block_id);
auto& ops = block.ops();
@@ -90,11 +90,11 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
- if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
+ if (IsTarget(op_desc) || HasDependentVar(op_desc, *dependent_vars)) {
// insert its input to the dependency graph
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
- dependent_vars.insert(argu);
+ dependent_vars->insert(argu);
}
}
should_run.push_back(true);
@@ -138,7 +138,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
- sub_block_dependent_vars);
+ &sub_block_dependent_vars);
}
}
}
@@ -181,7 +181,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
std::set<std::string> dependent_vars;
output->clear_blocks();
- prune_impl(input, output, 0, -1, dependent_vars);
+ prune_impl(input, output, 0, -1, &dependent_vars);
}
void inference_optimize_impl(proto::ProgramDesc* input, int block_id) {
diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc
index d2e60ab1dd..e5bc74755f 100644
--- a/paddle/fluid/framework/tensor_util.cc
+++ b/paddle/fluid/framework/tensor_util.cc
@@ -20,7 +20,7 @@ namespace paddle {
namespace framework {
void TensorCopy(const Tensor& src, const platform::Place& dst_place,
- const platform::DeviceContext& ctx, Tensor* dst, bool sync) {
+ const platform::DeviceContext& ctx, Tensor* dst) {
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
src.check_memory_size();
@@ -48,9 +48,7 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
auto stream =
- sync ? nullptr
- : reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
- .stream();
+ reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
@@ -61,9 +59,7 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
auto stream =
- sync ? nullptr
- : reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
- .stream();
+ reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream);
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
@@ -72,9 +68,7 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto stream =
- sync ? nullptr
- : reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
- .stream();
+ reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream);
}
#endif
@@ -92,6 +86,41 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
TensorCopy(src, dst_place, *dev_ctx, dst);
}
+void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
+ Tensor* dst) {
+ VLOG(3) << "TensorCopySync " << src.dims() << " from " << src.place()
+ << " to " << dst_place;
+ src.check_memory_size();
+ dst->Resize(src.dims());
+ dst->set_layout(src.layout());
+ auto src_place = src.place();
+ auto src_ptr = src.data<void>();
+ auto dst_ptr = dst->mutable_data(dst_place, src.type());
+ auto size = src.numel() * SizeOfType(src.type());
+ if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
+ memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
+ boost::get<platform::CPUPlace>(src_place), src_ptr, size);
+ }
+#ifdef PADDLE_WITH_CUDA
+ else if (platform::is_gpu_place(src_place) && // NOLINT
+ platform::is_cpu_place(dst_place)) {
+ auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
+ auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
+ memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr);
+ } else if (platform::is_cpu_place(src_place) &&
+ platform::is_gpu_place(dst_place)) {
+ auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
+ auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
+ memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, nullptr);
+ } else if (platform::is_gpu_place(src_place) &&
+ platform::is_gpu_place(dst_place)) {
+ auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
+ auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
+ memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr);
+ }
+#endif
+}
+
template
struct AnyDTypeVisitor {
Predicate predicate_;
diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h
index 3af68402dc..dca279b693 100644
--- a/paddle/fluid/framework/tensor_util.h
+++ b/paddle/fluid/framework/tensor_util.h
@@ -24,10 +24,11 @@ namespace paddle {
namespace framework {
void TensorCopy(const Tensor& src, const platform::Place& dst_place,
- const platform::DeviceContext& ctx, Tensor* dst,
- bool sync = false);
+ const platform::DeviceContext& ctx, Tensor* dst);
void TensorCopy(const Tensor& src, const platform::Place& dst_place,
Tensor* dst);
+void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
+ Tensor* dst);
template <typename T>
void TensorFromVector(const std::vector<T>& src,
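The header split above separates the asynchronous copy (tied to a device context and its stream) from the new fully synchronous TensorCopySync. A minimal sketch of how the two entry points differ, assuming a CUDA build and an existing CUDADeviceContext; illustrative only, not part of this patch.

#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"

void CopyBothWays(const paddle::framework::Tensor& src,
                  const paddle::platform::CUDADeviceContext& gpu_ctx) {
  paddle::framework::Tensor on_gpu, on_cpu;
  // Asynchronous: the copy is enqueued on gpu_ctx's stream and may still be
  // in flight when this call returns.
  paddle::framework::TensorCopy(src, gpu_ctx.GetPlace(), gpu_ctx, &on_gpu);
  // Synchronous: the destination is safe to read as soon as the call returns.
  paddle::framework::TensorCopySync(src, paddle::platform::CPUPlace(), &on_cpu);
}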
diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h
new file mode 100644
index 0000000000..6b0ac92fa9
--- /dev/null
+++ b/paddle/fluid/inference/engine.h
@@ -0,0 +1,52 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/framework/framework.pb.h"
+
+namespace paddle {
+namespace inference {
+
+/*
+ * EngineBase is the base class of all inference engines. An inference engine
+ * takes a paddle program as input, and outputs the result in fluid Tensor
+ * format. It can be used to optimize performance of computation sub-blocks, for
+ * example, break down the original block into sub-blocks and execute each
+ * sub-block in a different engine.
+ *
+ * For example:
+ * At inference time, most of a ResNet50 model can be put into a subgraph
+ * and run on a TensorRT engine.
+ *
+ * There are several engines such as TensorRT and other frameworks, so
+ * EngineBase is introduced to give a unified interface for all the
+ * different engine implementations.
+ */
+class EngineBase {
+ public:
+ using DescType = ::paddle::framework::proto::BlockDesc;
+
+ // Build the model and do some preparation, for example, in TensorRT, run
+ // createInferBuilder, buildCudaEngine.
+ virtual void Build(const DescType& paddle_model) = 0;
+
+ // Execute the engine, that will run the inference network.
+ virtual void Execute(int batch_size) = 0;
+
+ virtual ~EngineBase() {}
+}; // class EngineBase
+
+} // namespace inference
+} // namespace paddle
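To make the EngineBase contract concrete, a minimal subclass could look like the sketch below. NoopEngine is a hypothetical name used only for illustration; the real implementation added in this patch is the TensorRT engine further down.

#include "paddle/fluid/inference/engine.h"

namespace paddle {
namespace inference {

// Hypothetical engine that only shows which virtual methods a concrete
// engine is expected to provide; it does no real work.
class NoopEngine : public EngineBase {
 public:
  void Build(const DescType& paddle_model) override {
    // A real engine would translate the BlockDesc into its own network here.
    block_ = &paddle_model;
  }
  void Execute(int batch_size) override {
    // A real engine would run the prepared network for batch_size samples.
  }

 private:
  const DescType* block_{nullptr};
};

}  // namespace inference
}  // namespace paddle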
diff --git a/paddle/fluid/inference/io.cc b/paddle/fluid/inference/io.cc
index 78d2f16746..65db7c7b50 100644
--- a/paddle/fluid/inference/io.cc
+++ b/paddle/fluid/inference/io.cc
@@ -16,17 +16,29 @@ limitations under the License. */
#include
#include
+#include
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/pybind/pybind.h"
+DEFINE_string(devices, "", "The devices to be used, joined by commas.");
+DEFINE_bool(init_p2p, false, "Whether to init p2p.");
+
namespace paddle {
namespace inference {
-// Temporarily add this function for exposing framework::InitDevices() when
-// linking the inference shared library.
-void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
+void Init(const std::vector<std::string> argv) {
+ framework::InitGflags(argv);
+ // init devices
+ std::vector<int> devices;
+ std::string token;
+ std::istringstream tokenStream(FLAGS_devices);
+ while (std::getline(tokenStream, token, ',')) {
+ devices.push_back(std::stoi(token));
+ }
+ framework::InitDevices(FLAGS_init_p2p, devices);
+}
void ReadBinaryFile(const std::string& filename, std::string* contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
diff --git a/paddle/fluid/inference/io.h b/paddle/fluid/inference/io.h
index ba3e45099a..caf599b1a6 100644
--- a/paddle/fluid/inference/io.h
+++ b/paddle/fluid/inference/io.h
@@ -25,7 +25,7 @@ limitations under the License. */
namespace paddle {
namespace inference {
-void Init(bool init_p2p);
+void Init(const std::vector<std::string> argv);
void LoadPersistables(framework::Executor* executor, framework::Scope* scope,
const framework::ProgramDesc& main_program,
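With the new signature, callers pass gflags-style arguments instead of a single boolean. A hedged sketch of a caller is shown below; the flag values and the leading program-name element are assumptions about the gflags syntax, not something this patch prescribes.

#include <string>
#include <vector>
#include "paddle/fluid/inference/io.h"

int main() {
  // "--devices=0,1" feeds FLAGS_devices (comma-separated device ids); the
  // leading element plays the role of argv[0] and may be unnecessary.
  std::vector<std::string> argv = {"inference", "--devices=0,1",
                                   "--init_p2p=false"};
  paddle::inference::Init(argv);
  return 0;
}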
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index e39c0daac7..4b5866ad5d 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1 +1,4 @@
-nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+if(WITH_TESTING)
+ nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
+ nv_test(test_tensorrt_engine SRCS test_engine.cc engine.cc DEPS dynload_cuda)
+endif()
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
new file mode 100644
index 0000000000..03a25f8e8b
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -0,0 +1,135 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/engine.h"
+
+#include
+#include
+#include
+#include
+#include "paddle/fluid/inference/tensorrt/helper.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+void TensorRTEngine::Build(const DescType& paddle_model) {
+ PADDLE_ENFORCE(false, "not implemented");
+}
+
+void TensorRTEngine::Execute(int batch_size) {
+ infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr);
+ cudaStreamSynchronize(*stream_);
+}
+
+TensorRTEngine::~TensorRTEngine() {
+ // clean buffer
+ for (auto& buffer : buffers_) {
+ if (buffer != nullptr) {
+ PADDLE_ENFORCE_EQ(0, cudaFree(buffer));
+ buffer = nullptr;
+ }
+ }
+}
+
+void TensorRTEngine::FreezeNetwork() {
+ PADDLE_ENFORCE(infer_builder_ != nullptr,
+ "Call InitNetwork first to initialize network.");
+ PADDLE_ENFORCE(infer_network_ != nullptr,
+ "Call InitNetwork first to initialize network.");
+ // build engine.
+ infer_builder_->setMaxBatchSize(max_batch_);
+ infer_builder_->setMaxWorkspaceSize(max_workspace_);
+
+ infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
+ PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
+
+ infer_context_.reset(infer_engine_->createExecutionContext());
+
+ // allocate GPU buffers.
+ buffers_.resize(buffer_sizes_.size(), nullptr);
+ for (auto& item : buffer_sizes_) {
+ if (item.second == 0) {
+ auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
+ item.second = kDataTypeSize[static_cast<int>(
+ infer_engine_->getBindingDataType(slot_offset))] *
+ AccumDims(infer_engine_->getBindingDimensions(slot_offset));
+ }
+ PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second));
+ }
+}
+
+nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
+ nvinfer1::DataType dtype,
+ const nvinfer1::Dims& dim) {
+ PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
+ name);
+
+ PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
+ auto* input = infer_network_->addInput(name.c_str(), dtype, dim);
+ PADDLE_ENFORCE(input, "infer network add input %s failed", name);
+
+ buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim);
+ return input;
+}
+
+void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset,
+ const std::string& name) {
+ PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s",
+ name);
+
+ auto* output = layer->getOutput(offset);
+ PADDLE_ENFORCE(output != nullptr);
+ output->setName(name.c_str());
+ infer_network_->markOutput(*output);
+ // An output buffer's size can only be determined later; set it to zero here
+ // as a marker and reset it later.
+ buffer_sizes_[name] = 0;
+}
+
+void* TensorRTEngine::GetOutputInGPU(const std::string& name) {
+ return buffer(name);
+}
+
+void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst,
+ size_t max_size) {
+ // determine data size
+ auto it = buffer_sizes_.find(name);
+ PADDLE_ENFORCE(it != buffer_sizes_.end());
+ PADDLE_ENFORCE_GT(it->second, 0);
+ PADDLE_ENFORCE_GE(max_size, it->second);
+
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second,
+ cudaMemcpyDeviceToHost, *stream_));
+}
+
+void*& TensorRTEngine::buffer(const std::string& name) {
+ PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first.");
+ auto it = buffer_sizes_.find(name);
+ PADDLE_ENFORCE(it != buffer_sizes_.end());
+ auto slot_offset = infer_engine_->getBindingIndex(name.c_str());
+ return buffers_[slot_offset];
+}
+
+void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
+ size_t size) {
+ void* buf = buffer(name);
+ PADDLE_ENFORCE_EQ(
+ 0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_));
+}
+
+} // namespace tensorrt
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
new file mode 100644
index 0000000000..82d8c3df4e
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -0,0 +1,146 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include "paddle/fluid/inference/engine.h"
+#include "paddle/fluid/inference/tensorrt/helper.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * TensorRT Engine.
+ *
+ * There are two alternative ways to use it: one is to build it from a paddle
+ * protobuf model, the other is to construct the network manually.
+ */
+class TensorRTEngine : public EngineBase {
+ public:
+ // Weight is model parameter.
+ class Weight {
+ public:
+ Weight(nvinfer1::DataType dtype, void* value, int num_elem) {
+ w_.type = dtype;
+ w_.values = value;
+ w_.count = num_elem;
+ }
+ const nvinfer1::Weights& get() { return w_; }
+
+ private:
+ nvinfer1::Weights w_;
+ };
+
+ TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
+ nvinfer1::ILogger& logger = NaiveLogger::Global())
+ : max_batch_(max_batch),
+ max_workspace_(max_workspace),
+ stream_(stream),
+ logger_(logger) {}
+
+ virtual ~TensorRTEngine();
+
+ // TODO(Superjomn) implement it later when graph segmentation is supported.
+ void Build(const DescType& paddle_model) override;
+
+ void Execute(int batch_size) override;
+
+ // Initialize the inference network, so that TensorRT layers can add to this
+ // network.
+ void InitNetwork() {
+ infer_builder_.reset(createInferBuilder(logger_));
+ infer_network_.reset(infer_builder_->createNetwork());
+ }
+ // After all ops have been added, freeze this network and create the execution
+ // environment.
+ void FreezeNetwork();
+
+ // Add an input and set its name, data type and dimension.
+ nvinfer1::ITensor* DeclareInput(const std::string& name,
+ nvinfer1::DataType dtype,
+ const nvinfer1::Dims& dim);
+ // Set the offset-th output from a layer as the network's output, and set its
+ // name.
+ void DeclareOutput(const nvinfer1::ILayer* layer, int offset,
+ const std::string& name);
+
+ // GPU memory address for an ITensor with a specific name. One can operate on
+ // this memory directly for acceleration, for example, write converted data
+ // directly into the buffer to save data copy overhead.
+ // NOTE this should be used after calling `FreezeNetwork`.
+ void*& buffer(const std::string& name);
+
+ // Fill an input from CPU memory with name and size.
+ void SetInputFromCPU(const std::string& name, void* data, size_t size);
+ // TODO(Superjomn) is this method necessary, given that buffer(xxx) can be
+ // accessed directly? Fill an input from GPU memory with name and size.
+ void SetInputFromGPU(const std::string& name, void* data, size_t size);
+ // Get the output named `name`. TensorRT outputs live in GPU memory, so this
+ // method just returns the output's GPU memory address.
+ void* GetOutputInGPU(const std::string& name);
+ // LOW EFFICIENCY! Get output to CPU; this will trigger a memory copy from GPU
+ // to CPU.
+ void GetOutputInCPU(const std::string& name, void* dst, size_t max_size);
+
+ nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); }
+ nvinfer1::INetworkDefinition* network() { return infer_network_.get(); }
+
+ private:
+ // the max batch size
+ int max_batch_;
+ // the max memory size the engine uses
+ int max_workspace_;
+ cudaStream_t* stream_;
+ nvinfer1::ILogger& logger_;
+
+ std::vector<void*> buffers_;
+ // max data size for the buffers.
+ std::unordered_map<std::string, size_t> buffer_sizes_;
+
+ // TensorRT related internal members
+ template <typename T>
+ struct Destroyer {
+ void operator()(T* x) { x->destroy(); }
+ };
+ template <typename T>
+ using infer_ptr = std::unique_ptr<T, Destroyer<T>>;
+ infer_ptr<nvinfer1::IBuilder> infer_builder_;
+ infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
+ infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
+ infer_ptr<nvinfer1::IExecutionContext> infer_context_;
+}; // class TensorRTEngine
+
+// Add a layer__ into engine__ with args ARGS.
+// For example:
+// TRT_ENGINE_ADD_LAYER(xxx, FullyConnected, input, dim, weights, bias)
+//
+// Reference
+// https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#charRNN_define_network
+//
+// will add a fully connected layer into the engine.
+// TensorRT has too many layer types for it to be practical to add a member
+// function for each, and a macro like this is more extensible when the
+// underlying TensorRT library adds support for new layers.
+#define TRT_ENGINE_ADD_LAYER(engine__, layer__, ARGS...) \
+ engine__->network()->add##layer__(ARGS);
+
+} // namespace tensorrt
+} // namespace inference
+} // namespace paddle
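The expected call order for the manual-construction path declared above, condensed into one sketch. It mirrors the unit test added later in this patch (test_engine.cc); the names, sizes, and workspace value are made-up example values.

#include <cuda_runtime_api.h>
#include "paddle/fluid/inference/tensorrt/engine.h"

void RunTinyFc() {
  using paddle::inference::tensorrt::TensorRTEngine;
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  float w_data[1] = {2.f}, b_data[1] = {3.f};
  TensorRTEngine::Weight w(nvinfer1::DataType::kFLOAT, w_data, 1);
  TensorRTEngine::Weight b(nvinfer1::DataType::kFLOAT, b_data, 1);

  TensorRTEngine engine_obj(/*max_batch=*/1, /*max_workspace=*/1 << 20, &stream);
  TensorRTEngine* engine = &engine_obj;  // TRT_ENGINE_ADD_LAYER expects a pointer
  engine->InitNetwork();
  auto* x = engine->DeclareInput("x", nvinfer1::DataType::kFLOAT,
                                 nvinfer1::DimsCHW{1, 1, 1});
  auto* fc = TRT_ENGINE_ADD_LAYER(engine, FullyConnected, *x, 1, w.get(), b.get());
  engine->DeclareOutput(fc, 0, "y");
  engine->FreezeNetwork();

  float in = 1.f, out = 0.f;
  engine->SetInputFromCPU("x", &in, sizeof(float));
  engine->Execute(/*batch_size=*/1);
  engine->GetOutputInCPU("y", &out, sizeof(float));
  cudaStreamSynchronize(stream);  // the device-to-host copy is asynchronous
  cudaStreamDestroy(stream);
}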
diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h
new file mode 100644
index 0000000000..796283d325
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/helper.h
@@ -0,0 +1,88 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+#include "paddle/fluid/platform/dynload/tensorrt.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+namespace dy = paddle::platform::dynload;
+
+static size_t AccumDims(nvinfer1::Dims dims) {
+ size_t num = dims.nbDims == 0 ? 0 : 1;
+ for (int i = 0; i < dims.nbDims; i++) {
+ PADDLE_ENFORCE_GT(dims.d[i], 0);
+ num *= dims.d[i];
+ }
+ return num;
+}
+
+// TensorRT data type to size
+const int kDataTypeSize[] = {
+ 4, // kFLOAT
+ 2, // kHALF
+ 1, // kINT8
+ 4 // kINT32
+};
+
+// The following two APIs are implemented in TensorRT's header file, so they
+// cannot be loaded from the dynamic library. Provide our own implementations
+// that call the corresponding symbols in the dynamic library directly.
+static nvinfer1::IBuilder* createInferBuilder(nvinfer1::ILogger& logger) {
+ return static_cast<nvinfer1::IBuilder*>(
+ dy::createInferBuilder_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+static nvinfer1::IRuntime* createInferRuntime(nvinfer1::ILogger& logger) {
+ return static_cast<nvinfer1::IRuntime*>(
+ dy::createInferRuntime_INTERNAL(&logger, NV_TENSORRT_VERSION));
+}
+
+// A logger for creating the TensorRT infer builder.
+class NaiveLogger : public nvinfer1::ILogger {
+ public:
+ void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
+ switch (severity) {
+ case Severity::kINFO:
+ LOG(INFO) << msg;
+ break;
+ case Severity::kWARNING:
+ LOG(WARNING) << msg;
+ break;
+ case Severity::kINTERNAL_ERROR:
+ case Severity::kERROR:
+ LOG(ERROR) << msg;
+ break;
+ default:
+ break;
+ }
+ }
+
+ static nvinfer1::ILogger& Global() {
+ static nvinfer1::ILogger* x = new NaiveLogger;
+ return *x;
+ }
+
+ virtual ~NaiveLogger() override {}
+};
+
+} // namespace tensorrt
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/test_engine.cc b/paddle/fluid/inference/tensorrt/test_engine.cc
new file mode 100644
index 0000000000..c6e1c71cdc
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/test_engine.cc
@@ -0,0 +1,83 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/inference/tensorrt/engine.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class TensorRTEngineTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ ASSERT_EQ(0, cudaStreamCreate(&stream_));
+ engine_ = new TensorRTEngine(1, 1 << 10, &stream_);
+ engine_->InitNetwork();
+ }
+
+ void TearDown() override {
+ delete engine_;
+ cudaStreamDestroy(stream_);
+ }
+
+ protected:
+ TensorRTEngine* engine_;
+ cudaStream_t stream_;
+};
+
+TEST_F(TensorRTEngineTest, add_layer) {
+ const int size = 1;
+
+ float raw_weight[size] = {2.}; // Weight in CPU memory.
+ float raw_bias[size] = {3.};
+
+ LOG(INFO) << "create weights";
+ TensorRTEngine::Weight weight(nvinfer1::DataType::kFLOAT, raw_weight, size);
+ TensorRTEngine::Weight bias(nvinfer1::DataType::kFLOAT, raw_bias, size);
+ auto* x = engine_->DeclareInput("x", nvinfer1::DataType::kFLOAT,
+ nvinfer1::DimsCHW{1, 1, 1});
+ auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *x, size,
+ weight.get(), bias.get());
+ PADDLE_ENFORCE(fc_layer != nullptr);
+
+ engine_->DeclareOutput(fc_layer, 0, "y");
+ LOG(INFO) << "freeze network";
+ engine_->FreezeNetwork();
+ ASSERT_EQ(engine_->engine()->getNbBindings(), 2);
+
+ // fill in real data
+ float x_v = 1234;
+ engine_->SetInputFromCPU("x", reinterpret_cast<void*>(&x_v),
+ 1 * sizeof(float));
+ LOG(INFO) << "to execute";
+ engine_->Execute(1);
+
+ LOG(INFO) << "to get output";
+ // void* y_v =
+ float y_cpu;
+ engine_->GetOutputInCPU("y", &y_cpu, sizeof(float));
+
+ LOG(INFO) << "to checkout output";
+ ASSERT_EQ(y_cpu, x_v * 2 + 3);
+}
+
+} // namespace tensorrt
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/test_tensorrt.cc b/paddle/fluid/inference/tensorrt/test_tensorrt.cc
index a81a708e7a..aed5b5e1a2 100644
--- a/paddle/fluid/inference/tensorrt/test_tensorrt.cc
+++ b/paddle/fluid/inference/tensorrt/test_tensorrt.cc
@@ -1,16 +1,16 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+http://www.apache.org/licenses/LICENSE-2.0
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
#include
#include
diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
index 1e6555bb02..1a685b9e2e 100644
--- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
@@ -62,5 +62,21 @@ TEST(inference, image_classification) {
LOG(INFO) << output2.dims();
CheckError(output1, output2);
+
+ // float16 inference requires cuda GPUs with >= 5.3 compute capability
+ if (paddle::platform::GetCUDAComputeCapability(0) >= 53) {
+ paddle::framework::LoDTensor output3;
+ std::vector<paddle::framework::LoDTensor*> cpu_fetchs3;
+ cpu_fetchs3.push_back(&output3);
+
+ LOG(INFO) << "--- GPU Runs in float16 mode: ---";
+ std::string fp16_dirname = dirname;
+ fp16_dirname.replace(fp16_dirname.find("book/"),
+ std::string("book/").size(), "book/float16_");
+ TestInference(
+ fp16_dirname, cpu_feeds, cpu_fetchs3, FLAGS_repeat);
+
+ CheckError(output2, output3);
+ }
#endif
}
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index 117472599f..af2a7a5620 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -178,10 +178,10 @@ void TestInference(const std::string& dirname,
std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
if (PrepareContext) {
ctx = executor.Prepare(*inference_program, 0);
- executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
- CreateVars);
+ executor.RunPreparedContext(ctx.get(), scope, &feed_targets,
+ &fetch_targets, CreateVars);
} else {
- executor.Run(*inference_program, scope, feed_targets, fetch_targets,
+ executor.Run(*inference_program, scope, &feed_targets, &fetch_targets,
CreateVars);
}
@@ -197,10 +197,10 @@ void TestInference(const std::string& dirname,
if (PrepareContext) {
// Note: if you change the inference_program, you need to call
// executor.Prepare() again to get a new ExecutorPrepareContext.
- executor.RunPreparedContext(ctx.get(), scope, feed_targets,
- fetch_targets, CreateVars);
+ executor.RunPreparedContext(ctx.get(), scope, &feed_targets,
+ &fetch_targets, CreateVars);
} else {
- executor.Run(*inference_program, scope, feed_targets, fetch_targets,
+ executor.Run(*inference_program, scope, &feed_targets, &fetch_targets,
CreateVars);
}
}
diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h
index b332b67163..f82ff47b52 100644
--- a/paddle/fluid/operators/adam_op.h
+++ b/paddle/fluid/operators/adam_op.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <math.h> // for sqrt in CPU and CUDA
+#include
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
@@ -24,8 +25,14 @@ namespace operators {
namespace scatter = paddle::operators::math::scatter;
+struct GPUAdam;
+struct CPUAdam;
+
+template <typename T, typename Flavour>
+struct AdamFunctor;
+
template <typename T>
-struct AdamFunctor {
+struct AdamFunctor<T, GPUAdam> {
T beta1_;
T beta2_;
T epsilon_;
@@ -71,6 +78,7 @@ struct AdamFunctor {
// Calculation
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
@@ -82,6 +90,71 @@ struct AdamFunctor {
}
};
+template <typename T>
+struct AdamFunctor<T, CPUAdam> {
+ T beta1_;
+ T beta2_;
+ T epsilon_;
+
+ const T* beta1_pow_;
+ const T* beta2_pow_;
+ const T* moment1_;
+ T* moment1_out_;
+ const T* moment2_;
+ T* moment2_out_;
+ const T* lr_;
+ const T* grad_;
+ const T* param_;
+ T* param_out_;
+
+ AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
+ const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
+ T* mom2_out, const T* lr, const T* grad, const T* param,
+ T* param_out)
+ : beta1_(beta1),
+ beta2_(beta2),
+ epsilon_(epsilon),
+ beta1_pow_(beta1_pow),
+ beta2_pow_(beta2_pow),
+ moment1_(mom1),
+ moment1_out_(mom1_out),
+ moment2_(mom2),
+ moment2_out_(mom2_out),
+ lr_(lr),
+ grad_(grad),
+ param_(param),
+ param_out_(param_out) {}
+
+ void operator()(size_t numel) const {
+ Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> g{
+ grad_, static_cast<Eigen::Index>(numel)};
+ Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> mom1{
+ moment1_, static_cast<Eigen::Index>(numel)};
+ Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> mom2{
+ moment2_, static_cast<Eigen::Index>(numel)};
+ Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>> param{
+ param_, static_cast<Eigen::Index>(numel)};
+
+ Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param_out{
+ param_out_, static_cast<Eigen::Index>(numel)};
+ Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment1_out{
+ moment1_out_, static_cast<Eigen::Index>(numel)};
+ Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> moment2_out{
+ moment2_out_, static_cast<Eigen::Index>(numel)};
+
+ T lr = *lr_;
+ T beta1_pow = *beta1_pow_;
+ T beta2_pow = *beta2_pow_;
+
+ // Calculation
+ lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+ moment1_out = beta1_ * mom1 + (1 - beta1_) * g;
+ moment2_out = beta2_ * mom2 + (1 - beta2_) * g * g;
+ param_out = param - lr * (moment1_out / (moment2_out.sqrt() + epsilon_));
+ }
+};
+
template <typename T>
struct SparseAdamFunctor {
T beta1_;
@@ -134,6 +207,7 @@ struct SparseAdamFunctor {
T p = param_[rows_[i] * row_numel_ + j];
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
@@ -177,19 +251,34 @@ class AdamOpKernel : public framework::OpKernel<T> {
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = Ref(ctx.Input<framework::LoDTensor>("Grad"), "Must set Grad");
- AdamFunctor<T> functor(
- beta1, beta2, epsilon, beta1_pow.template data<T>(),
- beta2_pow.template data<T>(), mom1.template data<T>(),
- mom1_out.template mutable_data<T>(ctx.GetPlace()),
- mom2.template data<T>(),
- mom2_out.template mutable_data<T>(ctx.GetPlace()),
- lr.template data<T>(), grad.template data<T>(),
- param.template data<T>(),
- param_out.template mutable_data<T>(ctx.GetPlace()));
- platform::ForRange<DeviceContext> for_range(
- static_cast<const DeviceContext&>(ctx.device_context()),
- param.numel());
- for_range(functor);
+
+ if (platform::is_cpu_place(ctx.GetPlace())) {
+ AdamFunctor<T, CPUAdam> functor(
+ beta1, beta2, epsilon, beta1_pow.template data<T>(),
+ beta2_pow.template data<T>(), mom1.template data<T>(),
+ mom1_out.template mutable_data<T>(ctx.GetPlace()),
+ mom2.template data<T>(),
+ mom2_out.template mutable_data<T>(ctx.GetPlace()),
+ lr.template data<T>(), grad.template data<T>(),
+ param.template data<T>(),
+ param_out.template mutable_data<T>(ctx.GetPlace()));
+ functor(param.numel());
+ } else if (platform::is_gpu_place(ctx.GetPlace())) {
+ AdamFunctor<T, GPUAdam> functor(
+ beta1, beta2, epsilon, beta1_pow.template data<T>(),
+ beta2_pow.template data<T>(), mom1.template data<T>(),
+ mom1_out.template mutable_data<T>(ctx.GetPlace()),
+ mom2.template data<T>(),
+ mom2_out.template mutable_data<T>(ctx.GetPlace()),
+ lr.template data<T>(), grad.template data<T>(),
+ param.template data<T>(),
+ param_out.template mutable_data<T>(ctx.GetPlace()));
+
+ platform::ForRange<DeviceContext> for_range(
+ static_cast<const DeviceContext&>(ctx.device_context()),
+ param.numel());
+ for_range(functor);
+ }
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad =
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
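For reference, both functor flavours above (and the sparse variant) compute the standard Adam update; in the code's notation, with g the gradient and the beta powers accumulated per step:

\begin{aligned}
lr_t     &= lr \cdot \sqrt{1 - \beta_2^t} \,/\, (1 - \beta_1^t) \\
m_t      &= \beta_1 \, m_{t-1} + (1 - \beta_1)\, g_t \\
v_t      &= \beta_2 \, v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\theta_t &= \theta_{t-1} - lr_t \cdot m_t / (\sqrt{v_t} + \epsilon)
\end{aligned}

The CPUAdam specialization evaluates these expressions over whole arrays with Eigen, while the GPUAdam specialization keeps the original per-element form driven by ForRange.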
diff --git a/paddle/fluid/operators/beam_search_decode_op.h b/paddle/fluid/operators/beam_search_decode_op.h
index 4cb0457d92..3c01f81c83 100644
--- a/paddle/fluid/operators/beam_search_decode_op.h
+++ b/paddle/fluid/operators/beam_search_decode_op.h
@@ -223,8 +223,9 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
sentence_vector_list[src_idx].size());
}
- auto cpu_place = new paddle::platform::CPUPlace();
- paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place);
+ auto cpu_place = std::unique_ptr<paddle::platform::CPUPlace>(
+ new paddle::platform::CPUPlace());
+ paddle::platform::CPUDeviceContext cpu_ctx(*cpu_place.get());
framework::LoD lod;
lod.push_back(source_level_lod);
diff --git a/paddle/fluid/operators/bilinear_interp_op.cc b/paddle/fluid/operators/bilinear_interp_op.cc
new file mode 100644
index 0000000000..69f79bf93b
--- /dev/null
+++ b/paddle/fluid/operators/bilinear_interp_op.cc
@@ -0,0 +1,94 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/fluid/operators/bilinear_interp_op.h"
+#include
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class BilinearInterpOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(framework::InferShapeContext* ctx) const override {
+ PADDLE_ENFORCE(ctx->HasInput("X"),
+ "Input(X) of BilinearInterpOp should not be null.");
+ PADDLE_ENFORCE(ctx->HasOutput("Out"),
+ "Output(Out) of BilinearInterpOp should not be null.");
+
+ auto dim_x = ctx->GetInputDim("X"); // NCHW format
+ int out_h = ctx->Attrs().Get<int>("out_h");
+ int out_w = ctx->Attrs().Get<int>("out_w");
+ PADDLE_ENFORCE_EQ(dim_x.size(), 4, "X's dimension must be 4");
+
+ std::vector