Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into develop

revert-3824-remove_grad_op_type
zchen0211 8 years ago
commit ee29d1b6e2

@ -0,0 +1,15 @@
#!/bin/bash
set -e
readonly VERSION="3.8"
version=$(clang-format -version)
if ! [[ $version == *"$VERSION"* ]]; then
echo "clang-format version check failed."
echo "a version contains '$VERSION' is needed, but get '$version'"
echo "you can install the right version, and make an soft-link to '\$PATH' env"
exit -1
fi
clang-format $@

@ -19,10 +19,10 @@
- id: end-of-file-fixer - id: end-of-file-fixer
- repo: local - repo: local
hooks: hooks:
- id: clang-format - id: clang-format-with-version-check
name: clang-format name: clang-format
description: Format files with ClangFormat. description: Format files with ClangFormat.
entry: clang-format -i entry: ./.clang_format.hook -i
language: system language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$
- repo: https://github.com/PaddlePaddle/pre-commit-golang - repo: https://github.com/PaddlePaddle/pre-commit-golang

@ -71,20 +71,6 @@ RUN pip install -r /root/requirements.txt
RUN apt-get install -y libssl-dev libffi-dev RUN apt-get install -y libssl-dev libffi-dev
RUN pip install certifi urllib3[secure] RUN pip install certifi urllib3[secure]
# TODO(qijun) The template library Eigen doesn't work well with GCC 5
# coming with the default Docker image, so we switch to use GCC 4.8
# by default. And I will check Eigen library later.
RUN ln -sf gcc-4.8 /usr/bin/gcc && \
ln -sf gcc-ar-4.8 /usr/bin/gcc-ar && \
ln -sf gcc-nm-4.8 /usr/bin/gcc-nm && \
ln -sf gcc-ranlib-4.8 /usr/bin/gcc-ranlib && \
ln -sf gcc-4.8 /usr/bin/x86_64-linux-gnu-gcc && \
ln -sf gcc-ar-4.8 /usr/bin/x86_64-linux-gnu-gcc-ar && \
ln -sf gcc-nm-4.8 /usr/bin/x86_64-linux-gnu-gcc-nm && \
ln -sf gcc-ranlib-4.8 /usr/bin/x86_64-linux-gnu-gcc-ranlib && \
ln -sf g++-4.8 /usr/bin/g++ && \
ln -sf g++-4.8 /usr/bin/x86_64-linux-gnu-g++
# Install woboq_codebrowser to /woboq # Install woboq_codebrowser to /woboq
RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \ RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \

@ -9,13 +9,6 @@ function(CheckCompilerCXX11Flag)
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif() endif()
if(NOT ANDROID)
# TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem.
# Use Debug mode instead for now.
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9)
set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE)
endif()
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers. # Apple Clang is a different compiler than upstream Clang which havs different version numbers.
@ -160,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here. # So, don't set these flags here.
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread) LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
if(CMAKE_BUILD_TYPE STREQUAL "Debug") if(CMAKE_BUILD_TYPE STREQUAL "Debug")

@ -0,0 +1,101 @@
# Alalysis of large model distributed training in Paddle
***NOTE: This is only some note for how we implemeted this scheme in V1, not a new design.***
## What is it
We often encounter cases that the embedding layer parameters(sparse) are so large that we can not store it in the trainer's memory when training. So we need to put them to several servers, and fetch them row by row instead of fetch all of the parameters.
## How to use
Specify command-line argument like `--loadsave_parameters_in_pserver=true --ports_num_for_sparse=1 --use_old_updater=1` when starting the paddle trainer. And also add something like `--ports_num_for_sparse=1 --pserver_num_threads=5` when starting pserver processes.
Accrodingly, configure your embedding layers like:
```python
SPARSE_REMOTE=True
w1 = data_layer(name="w1", size=dict_size)
emb1 = embedding_layer(input=w1, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE))
w2 = data_layer(name="w2", size=dict_size)
emb2 = embedding_layer(input=w2, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE))
...
```
## Implementation details
```c++
enum MatType {
MAT_NORMAL,
MAT_NORMAL_SHARED,
MAT_VALUE_SHARED,
MAT_SPARSE_ROW_IDS,
MAT_SPARSE_ROW_AUTO_GROW,
MAT_CACHE_ROW,
MAT_SPARSE_ROW,
MAT_SPARSE_ROW_PREFETCH,
MAT_SPARSE_ROW_PREFETCH_FULL_SIZE,
};
```
`MAT_SPARSE_ROW_PREFETCH` is what we use when configured to fetch only row of matrix when training.
In `trainer_internal.cpp:L93 trainOneBatch`:
```c++
if (config_->getOptConfig().use_sparse_remote_updater()) {
REGISTER_TIMER("prefetch");
gradientMachine_->prefetch(inArgs);
parameterUpdater_->getParametersRemote();
}
```
When doing actual network forward and backward, at the beginning of each batch, the trainer will try to download one row of data from pserver.
In `trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`:
```c++
if (fullSize) {
...
} else {
getParams = [&] {
parameterClient_->getParameterSparse(
/* recvParameterType= */ PARAMETER_VALUE, sendBackParameterType);
};
applyL1 = [](Parameter& para, real decayRate) {
para.getMat(PARAMETER_VALUE)->applyL1(/*lr=*/1.0f, decayRate);
};
}
```
Calling `parameterClient_->getParameterSparse` will do remote call to pserver's `getParameterSparse`:
```c++
void ParameterServer2::getParameterSparse(const SendParameterRequest& request,
std::vector<Buffer>& inputBuffers,
SendParameterResponse* response,
std::vector<Buffer>* outputBuffers) {
(void)inputBuffers;
auto& buffer = *readWriteBuffer_;
size_t numReals = 0;
for (const auto& block : request.blocks()) {
numReals += getParameterConfig(block).dims(1);
}
buffer.resize(numReals);
VLOG(3) << "pserver: getParameterSparse, numReals=" << numReals;
ReadLockGuard guard(parameterMutex_);
size_t offset = 0;
for (const auto& block : request.blocks()) {
size_t width = getParameterConfig(block).dims(1);
Buffer buf = {buffer.data() + offset, width};
int type = request.send_back_parameter_type();
sendBackParameterSparse(block, type, response, &buf, width, outputBuffers);
offset += width;
}
}
```
`getParameterConfig(block).dims(1)` returns the width of the current "parameter block"(a shard of parameter object),
then `getParameterSparse` remote call returns only one row of data to the client.

@ -101,6 +101,7 @@ if use_mkldnn
5. 在**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue`和`mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。 5. 在**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue`和`mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。
6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`并针对device在MKL-DNN和CPU之间不统一的情况做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。 6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`并针对device在MKL-DNN和CPU之间不统一的情况做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。
7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag用于选择是否使用MKL-DNN的相关功能。 7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag用于选择是否使用MKL-DNN的相关功能。
8. 关于MKLDNN参数的保存。由于MKLDNN参数的格式与PaddlePaddle原有的格式存在不一样的情况所以需要在保存参数时同时保存该格式信息。目前准备扩展[Header](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/parameter/Parameter.h#L247)里面的`int32_t version`。这个值不管是在v1还是在v2里面一直保存的是0所以可以充分利用这个信息定义一个枚举处理所有MKLDNN的参数格式从而`MKLDNNLayer`就可以从输入的参数中获取需要的格式信息。
## References ## References

@ -68,7 +68,7 @@ As a simple example, consider the following:
1. **BLAS Dependencies(optional)** 1. **BLAS Dependencies(optional)**
CMake will search BLAS libraries from system. If not found, OpenBLAS will be downloaded, built and installed automatically. CMake will search BLAS libraries from the system. If not found, OpenBLAS will be downloaded, built and installed automatically.
To utilize preinstalled BLAS you can simply specify MKL, OpenBLAS or ATLAS via `MKL_ROOT`, `OPENBLAS_ROOT` or `ATLAS_ROOT`. To utilize preinstalled BLAS you can simply specify MKL, OpenBLAS or ATLAS via `MKL_ROOT`, `OPENBLAS_ROOT` or `ATLAS_ROOT`.
```bash ```bash
@ -131,9 +131,9 @@ As a simple example, consider the following:
To build GPU version, you will need the following installed: To build GPU version, you will need the following installed:
1. a CUDA-capable GPU 1. a CUDA-capable GPU
2. A supported version of Linux with a gcc compiler and toolchain 2. A supported version of Linux with a GCC compiler and toolchain
3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads) 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn) 4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
The CUDA development environment relies on tight integration with the host development environment, The CUDA development environment relies on tight integration with the host development environment,
including the host compiler and C runtime libraries, and is therefore only supported on including the host compiler and C runtime libraries, and is therefore only supported on
@ -172,6 +172,7 @@ export PATH=<path to install>/bin:$PATH
# install PaddlePaddle Python modules. # install PaddlePaddle Python modules.
sudo pip install <path to install>/opt/paddle/share/wheels/*.whl sudo pip install <path to install>/opt/paddle/share/wheels/*.whl
``` ```
## <span id="centos">Build on Centos 7</span> ## <span id="centos">Build on Centos 7</span>
### Install Dependencies ### Install Dependencies
@ -192,9 +193,9 @@ sudo pip install <path to install>/opt/paddle/share/wheels/*.whl
To build GPU version, you will need the following installed: To build GPU version, you will need the following installed:
1. a CUDA-capable GPU 1. a CUDA-capable GPU
2. A supported version of Linux with a gcc compiler and toolchain 2. A supported version of Linux with a GCC compiler and toolchain
3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads) 3. NVIDIA CUDA Toolkit (available at http://developer.nvidia.com/cuda-downloads)
4. NVIDIA cuDNN Library (availabel at https://developer.nvidia.com/cudnn) 4. NVIDIA cuDNN Library (available at https://developer.nvidia.com/cudnn)
The CUDA development environment relies on tight integration with the host development environment, The CUDA development environment relies on tight integration with the host development environment,
including the host compiler and C runtime libraries, and is therefore only supported on including the host compiler and C runtime libraries, and is therefore only supported on

@ -146,3 +146,19 @@ paddle_error paddle_gradient_machine_randomize_param(
m->machine->randParameters(); m->machine->randParameters();
return kPD_NO_ERROR; return kPD_NO_ERROR;
} }
paddle_error paddle_gradient_machine_get_layer_output(
paddle_gradient_machine machine,
const char* layerName,
paddle_arguments args) {
auto m = cast(machine);
auto out = paddle::capi::cast<paddle::capi::CArguments>(args);
if (m == nullptr || layerName == nullptr || out == nullptr ||
m->machine == nullptr) {
return kPD_NULLPTR;
}
auto layerOutput = m->machine->getLayerOutput(layerName);
out->args.push_back(layerOutput);
return kPD_NO_ERROR;
}

@ -39,7 +39,11 @@ PD_API paddle_error paddle_gradient_machine_create_for_inference(
/** /**
* @brief Create a gradient machine used for model inference, using config with * @brief Create a gradient machine used for model inference, using config with
* parameters which is generated by `paddle merge_model`. * parameters which is generated by `paddle merge_model`.
* @param [out] machine that used for model inference. * Example:
* paddle merge_model \
* --model_dir="pass-00000" \
* --model_file="merged_model.paddle"
* @param [out] machine that used for model inference
* @param [in] mergedModel * @param [in] mergedModel
* @param [in] size * @param [in] size
* @return paddle_error * @return paddle_error
@ -97,6 +101,18 @@ paddle_gradient_machine_randomize_param(paddle_gradient_machine machine);
PD_API paddle_error PD_API paddle_error
paddle_gradient_machine_destroy(paddle_gradient_machine machine); paddle_gradient_machine_destroy(paddle_gradient_machine machine);
/**
* @brief Get the output of the layer named `layerName`.
* @param [in] gradient machine that have run a inference
* @param [in] layerName name of specified layer
* @param [out] args output of the specified layer
* @return paddle_error
*/
PD_API paddle_error
paddle_gradient_machine_get_layer_output(paddle_gradient_machine machine,
const char* layerName,
paddle_arguments args);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

@ -40,7 +40,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
cc_library(backward SRCS backward.cc DEPS net_op) cc_library(backward SRCS backward.cc DEPS net_op)
cc_test(backward_test SRCS backward_test.cc DEPS backward) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context)
if(WITH_PYTHON) if(WITH_PYTHON)
cc_library(paddle_pybind SHARED cc_library(paddle_pybind SHARED

@ -17,6 +17,7 @@
#include <list> #include <list>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -178,6 +179,22 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
return false; return false;
}); });
// process recurrent gradient op as a special operator.
if (forwardOp.Type() == "recurrent_op") {
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
// this will result in infinite loop.
const auto& rnnop =
*static_cast<const operators::RecurrentOp*>(&forwardOp);
auto rnn_grad_op =
static_cast<operators::RecurrentGradientOp*>(grad_op.get());
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
rnn_grad_op->set_stepnet(
std::static_pointer_cast<operators::NetOp>(grad_stepnet));
}
if (net->ops_.empty()) { // Current no aux op is added to network if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op; return grad_op;
} }

@ -17,5 +17,49 @@ limitations under the License. */
#include <vector> #include <vector>
namespace paddle { namespace paddle {
namespace framework {} // namespace framework namespace framework {
std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs) {
auto it = op_info_map().find(type);
PADDLE_ENFORCE(it != op_info_map().end(),
"Operator '%s' has not been registered.", type);
it->second.checker_->Check(attrs);
auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::shared_ptr<OperatorBase>(op);
}
std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr);
}
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}
OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
VarNameMap ret_val;
for (auto& var : op_desc_vars) {
auto& var_names = ret_val[var.parameter()];
auto& var_names_in_proto = var.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}
return ret_val;
}
std::shared_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
return grad_op;
}
} // namespace framework
} // namespace paddle } // namespace paddle

@ -29,103 +29,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
protected:
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
// means that input/output is not needed when calculate gradient. It does
// not mean no gradient when backward. It should be changed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name,
const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return VariableBuilder{input};
}
VariableBuilder AddOutput(const std::string& name,
const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return VariableBuilder{output};
}
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}
OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
class OpRegistry { class OpRegistry {
using VarNameMap = OperatorBase::VarNameMap; using VarNameMap = OperatorBase::VarNameMap;
using OpCreator = std::function<OperatorBase*( using OpCreator = std::function<OperatorBase*(
@ -177,45 +80,14 @@ class OpRegistry {
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type, static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs, const VarNameMap& inputs,
const VarNameMap& outputs, const VarNameMap& outputs,
AttributeMap attrs) { AttributeMap attrs);
auto it = op_info_map().find(type);
PADDLE_ENFORCE(it != op_info_map().end(),
"Operator '%s' has not been registered.", type);
it->second.checker_->Check(attrs);
auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::shared_ptr<OperatorBase>(op);
}
static VarNameMap ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars) {
VarNameMap ret_val;
for (auto& var : op_desc_vars) {
auto& var_names = ret_val[var.parameter()];
auto& var_names_in_proto = var.arguments();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}
return ret_val;
}
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) { static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr);
}
return CreateOp(op_desc.type(), inputs, outputs, attrs); static VarNameMap ConvertOpDescVarsToVarNameMap(
} const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars);
static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op) { static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
PADDLE_ENFORCE(!op.IsNetOp(),
"Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
return grad_op;
}
static std::unordered_map<std::string, const OpInfo>& op_info_map() { static std::unordered_map<std::string, const OpInfo>& op_info_map() {
static std::unordered_map<std::string, const OpInfo> op_info_map_; static std::unordered_map<std::string, const OpInfo> op_info_map_;

@ -164,5 +164,43 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
return ret_val; return ret_val;
} }
void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
const std::string& name, const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{input};
}
OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
const std::string& name, const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{output};
}
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -138,6 +138,74 @@ class NOP : public OperatorBase {
const platform::DeviceContext& dev_ctx) const override {} const platform::DeviceContext& dev_ctx) const override {}
}; };
// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}
~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}
void Validate();
protected:
struct VariableBuilder {
OpProto::Var* var_;
VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}
VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}
// TODO(FengJiayi, yuyang18): `AsNoGradient` is a very bad name, because it
// means that input/output is not needed when calculate gradient. It does
// not mean no gradient when backward. It should be changed soon.
VariableBuilder& AsNoGradient() {
var_->set_no_gradient(true);
return *this;
}
};
VariableBuilder AddInput(const std::string& name, const std::string& comment);
VariableBuilder AddOutput(const std::string& name,
const std::string& comment);
template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
private:
void CheckNoDuplicatedInOutAttrs();
OpProto* proto_;
OpAttrChecker* op_checker_;
bool validated_{false};
};
class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};
class InferShapeContext { class InferShapeContext {
public: public:
InferShapeContext(const OperatorBase& op, const Scope& scope) InferShapeContext(const OperatorBase& op, const Scope& scope)

@ -1,7 +1,7 @@
add_subdirectory(detail) add_subdirectory(detail)
cc_library(memory SRCS memory.cc) cc_library(memory SRCS memory.cc)
cc_library(memcpy SRCS memcpy.cc DEPS device_context) cc_library(memcpy SRCS memcpy.cc)
cc_library(paddle_memory cc_library(paddle_memory
DEPS DEPS

@ -16,8 +16,6 @@ limitations under the License. */
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace memory { namespace memory {

@ -127,7 +127,7 @@ class RecurrentOp final : public framework::OperatorBase {
} }
void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; } void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); } const NetOp& stepnet() const { return *stepnet_; }
static const rnn::ArgumentName kArgName; static const rnn::ArgumentName kArgName;
@ -158,7 +158,7 @@ class RecurrentGradientOp final : public framework::OperatorBase {
static const rnn::ArgumentName kArgName; static const rnn::ArgumentName kArgName;
void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; } void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); } const NetOp& stepnet() const { return *stepnet_; }
private: private:
RecurrentGradientAlgorithm alg_; RecurrentGradientAlgorithm alg_;

@ -16,5 +16,8 @@ ELSE()
set(GPU_CTX_DEPS) set(GPU_CTX_DEPS)
ENDIF() ENDIF()
cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS}) # memcpy deoends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator
system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info)

@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/memory/memory.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
@ -36,6 +37,59 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
class EigenCudaStreamDevice : public Eigen::StreamInterface {
public:
EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
Eigen::initializeDeviceProp();
}
~EigenCudaStreamDevice() override {}
void Reinitialize(const cudaStream_t* cuda_stream, GPUPlace place) {
stream_ = cuda_stream;
place_ = place;
device_prop_ = &Eigen::m_deviceProperties[place.device];
}
const cudaStream_t& stream() const override { return *stream_; }
const cudaDeviceProp& deviceProperties() const override {
return *device_prop_;
}
void* allocate(size_t num_bytes) const override {
return paddle::memory::Alloc(place_, num_bytes);
}
void deallocate(void* buffer) const override {
paddle::memory::Free(place_, buffer);
}
void* scratchpad() const override {
if (scratch_ == NULL) {
scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
}
return scratch_;
}
unsigned int* semaphore() const override {
if (semaphore_ == NULL) {
char* scratch =
static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
semaphore_ = reinterpret_cast<unsigned int*>(scratch);
PADDLE_ENFORCE(
cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
}
return semaphore_;
}
private:
GPUPlace place_;
const cudaStream_t* stream_; // not owned;
const cudaDeviceProp* device_prop_; // not owned;
mutable void* scratch_;
mutable unsigned int* semaphore_;
};
template <> template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const { Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
@ -43,19 +97,9 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
// TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly PADDLE_ENFORCE(cudaStreamCreate(&stream_));
// here will cause segment fault. We must implement a class derived from eigen_stream_.reset(new EigenCudaStreamDevice());
// Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id eigen_stream_->Reinitialize(&stream_, place);
// later. Please refer to the implementation of class EigenCudaStreamDevice
// in TensorFlow.
//
// We find that CUDA 7 introduces a new option, the per-thread default stream,
// that has two effects. Please refer to https://devblogs.nvidia.com/
// parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
//
// So, we decide to use default stream and add default-stream per-thread nvcc
// flag. Than, two threads with two CUDADeviceContexts will run parallelly.
eigen_stream_.reset(new Eigen::CudaStreamDevice());
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
} }
@ -75,12 +119,13 @@ CUDADeviceContext::~CUDADeviceContext() {
} }
eigen_stream_.reset(); eigen_stream_.reset();
eigen_device_.reset(); eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
} }
Place CUDADeviceContext::GetPlace() const { return place_; } Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const { void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(0)); PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
} }
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
@ -91,6 +136,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() {
if (!cublas_handle_) { if (!cublas_handle_) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
} }
return cublas_handle_; return cublas_handle_;
} }
@ -99,10 +145,13 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) { if (!cudnn_handle_) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
} }
return cudnn_handle_; return cudnn_handle_;
} }
cudaStream_t CUDADeviceContext::stream() { return stream_; }
curandGenerator_t CUDADeviceContext::curand_generator() { curandGenerator_t CUDADeviceContext::curand_generator() {
if (!curand_generator_) { if (!curand_generator_) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
@ -110,6 +159,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
CURAND_RNG_PSEUDO_DEFAULT)); CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE( PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
} }
return curand_generator_; return curand_generator_;
} }

@ -52,6 +52,7 @@ class CPUDeviceContext : public DeviceContext {
}; };
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
public: public:
@ -76,6 +77,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return curand handle in the device context. */ /*! \brief Return curand handle in the device context. */
curandGenerator_t curand_generator(); curandGenerator_t curand_generator();
/*! \brief Return cuda stream in the device context. */
cudaStream_t stream();
// clang-format on // clang-format on
private: private:
@ -83,15 +87,16 @@ class CUDADeviceContext : public DeviceContext {
private: private:
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<Eigen::CudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
private: private:
uint64_t seed_; uint64_t seed_;
// clang-format off // clang-format off
cudnnHandle_t cudnn_handle_ = nullptr; cudaStream_t stream_{nullptr};
cublasHandle_t cublas_handle_ = nullptr; cudnnHandle_t cudnn_handle_{nullptr};
curandGenerator_t curand_generator_ = nullptr; cublasHandle_t cublas_handle_{nullptr};
curandGenerator_t curand_generator_{nullptr};
// clang-format on // clang-format on
}; };

@ -45,6 +45,7 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE(nullptr, cublas_handle); ASSERT_NE(nullptr, cublas_handle);
curandGenerator_t curand_handle = device_context->curand_generator(); curandGenerator_t curand_handle = device_context->curand_generator();
ASSERT_NE(nullptr, curand_handle); ASSERT_NE(nullptr, curand_handle);
ASSERT_NE(nullptr, device_context->stream());
delete device_context; delete device_context;
} }
} }

@ -86,7 +86,7 @@ struct EnforceNotMet : public std::exception {
2 + sizeof(void*) * 2, call_stack[i], 2 + sizeof(void*) * 2, call_stack[i],
demangled, addr_offset); demangled, addr_offset);
} else { } else {
sout << string::Sprintf("%-3d %*0p %s\n", i, 2 + sizeof(void*) * 2, sout << string::Sprintf("%-3d %*0p\n", i, 2 + sizeof(void*) * 2,
call_stack[i]); call_stack[i]);
} }
} }

@ -82,10 +82,6 @@ EOF
fi fi
# To build documentation, we need to run cmake again after installing
# PaddlePaddle. This awkwardness is due to
# https://github.com/PaddlePaddle/Paddle/issues/1854. It also
# describes a solution.
if [[ ${WITH_DOC:-OFF} == "ON" ]]; then if [[ ${WITH_DOC:-OFF} == "ON" ]]; then
cat <<EOF cat <<EOF
======================================== ========================================
@ -93,11 +89,6 @@ Building documentation ...
In /paddle/build_doc In /paddle/build_doc
======================================== ========================================
EOF EOF
# build documentation need install Paddle before
make install -j `nproc`
pip install /usr/local/opt/paddle/share/wheels/*.whl
paddle version
mkdir -p /paddle/build_doc mkdir -p /paddle/build_doc
pushd /paddle/build_doc pushd /paddle/build_doc
cmake .. \ cmake .. \
@ -106,7 +97,8 @@ EOF
-DWITH_AVX=${WITH_AVX:-ON} \ -DWITH_AVX=${WITH_AVX:-ON} \
-DWITH_SWIG_PY=ON \ -DWITH_SWIG_PY=ON \
-DWITH_STYLE_CHECK=OFF -DWITH_STYLE_CHECK=OFF
make paddle_docs paddle_docs_cn make -j `nproc` gen_proto_py
make -j `nproc` paddle_docs paddle_docs_cn
popd popd
fi fi
@ -128,25 +120,6 @@ EOF
/woboq/indexgenerator/codebrowser_indexgenerator $WOBOQ_OUT /woboq/indexgenerator/codebrowser_indexgenerator $WOBOQ_OUT
fi fi
# generate deb package for current build
# FIXME(typhoonzero): should we remove paddle/scripts/deb ?
if [[ ${WITH_DEB:-ON} == "ON" ]]; then
cat <<EOF
========================================
Generating .deb package ...
========================================
EOF
set +e
cpack -D CPACK_GENERATOR='DEB' -j `nproc` ..
err_code=$?
if [ ${err_code} -ne 0 ]; then
# cat error logs if cpack failed.
cat /paddle/build/_CPack_Packages/Linux/DEB/PreinstallOutput.log
exit ${err_code}
fi
set -e
fi
cat <<EOF cat <<EOF
======================================== ========================================
Generate /paddle/build/Dockerfile ... Generate /paddle/build/Dockerfile ...
@ -166,14 +139,13 @@ EOF
fi fi
cat >> /paddle/build/Dockerfile <<EOF cat >> /paddle/build/Dockerfile <<EOF
# Use different deb file when building different type of images ADD python/dist/*.whl /
ADD *.deb /
# run paddle version to install python packages first # run paddle version to install python packages first
RUN apt-get update &&\ RUN apt-get update &&\
apt-get install -y wget python-pip && pip install -U pip && \ apt-get install -y wget python-pip && pip install -U pip && \
dpkg -i /*.deb ; apt-get install -f -y && \ pip install /*.whl; apt-get install -f -y && \
apt-get clean -y && \ apt-get clean -y && \
rm -f /*.deb && \ rm -f /*.whl && \
paddle version paddle version
${DOCKERFILE_CUDNN_DSO} ${DOCKERFILE_CUDNN_DSO}
${DOCKERFILE_GPU_ENV} ${DOCKERFILE_GPU_ENV}
@ -182,3 +154,7 @@ ADD go/cmd/master/master /usr/bin/
# default command shows the paddle version and exit # default command shows the paddle version and exit
CMD ["paddle", "version"] CMD ["paddle", "version"]
EOF EOF
set +xe
printf "If you need to install PaddlePaddle in develop docker image,"
printf "please make install or pip install build/python/dist/*.whl.\n"

@ -56,8 +56,7 @@ if [ -z "${PADDLE_NO_STAT+x}" ]; then
fi fi
fi fi
PADDLE_BIN_PATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
MYDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
if [ ! -z "${DEBUGGER}" ]; then if [ ! -z "${DEBUGGER}" ]; then
echo "Using debug command ${DEBUGGER}" echo "Using debug command ${DEBUGGER}"
@ -93,34 +92,16 @@ else:
sys.exit(0) sys.exit(0)
EOF EOF
if [ $? -eq 1 ]; then # Older version installed, or not installed at all
echo "First time run paddle, need to install some python dependencies."
# setuptools normalizes package version, so we need to use normalized
# package version for paddle python package
PYTHON_PADDLE_VERSION=$(python -c 'import packaging.version
import setuptools
print str(packaging.version.Version("@PADDLE_VERSION@"))
' 2>/dev/null)
BASEDIR=$(dirname "$0")
pip install ${BASEDIR}/../opt/paddle/share/wheels/*-${PYTHON_PADDLE_VERSION}-*.whl
if [ $? -ne 0 ]; then
echo "pip install wheels failed. "
echo "Please use 'sudo paddle' at the first time you use PaddlePaddle"
echo "PaddlePaddle will install some python dependencies automatically."
exit 1
fi
echo "Python dependencies are installed."
fi
case "$1" in case "$1" in
"train") "train")
${DEBUGGER} $MYDIR/../opt/paddle/bin/paddle_trainer ${@:2} ${DEBUGGER} $PADDLE_BIN_PATH/paddle_trainer ${@:2}
;; ;;
"merge_model") "merge_model")
${DEBUGGER} $MYDIR/../opt/paddle/bin/paddle_merge_model ${@:2} ${DEBUGGER} $PADDLE_BIN_PATH/paddle_merge_model ${@:2}
;; ;;
"pserver") "pserver")
${DEBUGGER} $MYDIR/../opt/paddle/bin/paddle_pserver_main ${@:2} ${DEBUGGER} $PADDLE_BIN_PATH/paddle_pserver_main ${@:2}
;; ;;
"dump_config") "dump_config")
python -m paddle.utils.dump_config ${@:2} python -m paddle.utils.dump_config ${@:2}
@ -129,7 +110,7 @@ case "$1" in
python -m paddle.utils.make_model_diagram ${@:2} python -m paddle.utils.make_model_diagram ${@:2}
;; ;;
"usage") "usage")
$MYDIR/../opt/paddle/bin/paddle_usage ${@:2} $PADDLE_BIN_PATH/paddle_usage ${@:2}
;; ;;
"version") "version")
version version

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save