diff --git a/CMakeLists.txt b/CMakeLists.txt
index 23bbe829ac..030bd19b3f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -25,7 +25,6 @@ message(STATUS "CXX compiler: ${CMAKE_CXX_COMPILER}, version: "
message(STATUS "C compiler: ${CMAKE_C_COMPILER}, version: "
"${CMAKE_C_COMPILER_ID} ${CMAKE_C_COMPILER_VERSION}")
-find_package(Sphinx)
if(NOT CMAKE_CROSSCOMPILING)
find_package(CUDA QUIET)
endif(NOT CMAKE_CROSSCOMPILING)
@@ -226,5 +225,7 @@ if(WITH_PYTHON)
endif()
if(WITH_DOC)
+ find_package(Sphinx REQUIRED)
+ find_python_module(recommonmark REQUIRED)
add_subdirectory(doc)
endif()
diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake
index 966c0bafd3..0332e39d14 100644
--- a/cmake/external/mkldnn.cmake
+++ b/cmake/external/mkldnn.cmake
@@ -56,6 +56,8 @@ ExternalProject_Add(
GIT_TAG "v0.14"
PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND ""
+ # Patch MKLDNN to compile with gcc 4.8, the related issue is in intel/mkl-dnn#237.
+ PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/mkldnn.hpp ${MKLDNN_SOURCES_DIR}/src/extern_mkldnn/include/mkldnn.hpp
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${MKLDNN_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
CMAKE_ARGS -DMKLROOT=${MKLML_ROOT}
diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index cc75801982..06a7ae5682 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -70,6 +70,12 @@ copy(glog_lib
DSTS ${dst_dir} ${dst_dir}/lib
)
+set(dst_dir "${CMAKE_INSTALL_PREFIX}/third_party/boost/")
+copy(boost_lib
+ SRCS ${BOOST_INCLUDE_DIR}/boost
+ DSTS ${dst_dir}
+)
+
if(NOT PROTOBUF_FOUND)
set(dst_dir "${CMAKE_INSTALL_PREFIX}/third_party/install/protobuf")
copy(protobuf_lib
diff --git a/doc/fluid/design/dist_train/async_update.md b/doc/fluid/design/dist_train/async_update.md
index 6a0835b761..248d2ec18d 100644
--- a/doc/fluid/design/dist_train/async_update.md
+++ b/doc/fluid/design/dist_train/async_update.md
@@ -4,34 +4,37 @@
For the typical synchronous distributed training, some significant steps are as follows:
-1. A Trainer will compute the gradients and SEND them to the Parameter Server(PServer) nodes.
-1. After the PServer node received gradients came from all the Trainers, It will aggregate the
+1. A trainer process will compute the gradients and **send** them to the parameter server (PS) nodes.
+1. After the PS node has received gradients from all the trainers, it will aggregate the
gradient variables for the same parameter into one gradient variable and then apply the aggregated
gradient to the respective parameter, finally using an optimize algorithms(SGD, Monument...)
to update the parameters.
-1. The Trainer would wait for the PServers finished the optimize stage, and GET the parameters from PServer,
+1. The trainer would wait until the PS has finished the optimize stage, and then GET the parameters from the PS,
so all the Trainers would get the same parameters.
-In the synchronously distributed training, there should be a `Barrier` to synchronise the
-parameters after the optimizing stage. The performance of a distributed training job would
-depend on the slowest node if there were hundreds or thousands of training nodes in a
-Job, the performance of synchronously distributed training might be very poor because of
-the slow node. So this design doc would introduce an approach to implement
-*asynchronously* distributed training in PaddlePaddle Fluid.
+In Synchronous Distributed Training, there is a **barrier** on each PS to wait until all trainer processes
+have completed running the current mini-batch. After that, all trainers can continue to run the next
+mini-batch. So, we can see that the overall performance of Synchronous Distributed Training depends
+on the slowest node.
+
+In Asynchronous Distributed Training, we don't need to wait for a global mini-batch: the optimizer on
+the PS will run immediately when a gradient is uploaded to the PS from any trainer. This mode can
+improve the scalability and throughput of training. In this design doc, we will introduce how to
+implement Asynchronous Distributed Training based on PaddlePaddle Fluid.
## Design
-As the figure above, we describe a global view of asynchronously update process and use
+As the figure above, we describe a global view of the asynchronous update process and use
the parameter `w1` as an example to introduce the steps:
1. For each gradient variables, they may distribute on different GPU card and aggregate
them while they are all calculated.
-1. Split the gradient variable into multiple blocks according to the number of PServer
+1. Split the gradient variable into multiple blocks according to the number of PS
instances and then send them.
-1. PServer would run an `Optimize Block` using a specified optimize algorithm to update
+1. PS would run an `Optimize Block` using a specified optimize algorithm to update
the specified parameter.
-1. The trainer will fetch latest parameter from PServer before running forward Op which depends
+1. The trainer will fetch the latest parameter from PS before running forward Op which depends
on the specified parameter.
1. Broadcast the received variable into multiple GPU cards and continue to run the next
mini-batch.
@@ -40,8 +43,8 @@ mini-batch.
- For the multiple devices distributed training, we need to aggregate the gradient
variables which placed on different devices firstly and then schedule a `SendVars` Operator to
-send the gradient variables to the multiple PServer instances.
-- Schedule `FetchVars` operator to fetch the latest parameter from PServer before running
+send the gradient variables to the multiple PS instances.
+- Schedule `FetchVars` operator to fetch the latest parameter from PS before running
the forward ops.
- There could be a large number of gradient variables to be sent, so we need to use another
thread pool(IO Threadpool) whose a number of the schedulable threads is larger than the
diff --git a/doc/v2/build_and_install/build_from_source_cn.rst b/doc/v2/build_and_install/build_from_source_cn.rst
index 115b92a338..f846928954 100644
--- a/doc/v2/build_and_install/build_from_source_cn.rst
+++ b/doc/v2/build_and_install/build_from_source_cn.rst
@@ -19,8 +19,9 @@
----------------
PaddlePaddle需要使用Docker环境完成编译,这样可以免去单独安装编译依赖的步骤,可选的不同编译环境Docker镜像
-可以在 `这里 `_ 找到。或者
-参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。
+可以在 `这里 `_ 找到,您也可以
+在 `这里 `_ 找到 paddle_manylinux_devel
+镜像的编译以及使用方法。或者参考下述可选步骤,从源码中构建用于编译PaddlePaddle的Docker镜像。
如果您选择不使用Docker镜像,则需要在本机安装下面章节列出的 `编译依赖`_ 之后才能开始编译的步骤。
diff --git a/doc/v2/build_and_install/build_from_source_en.rst b/doc/v2/build_and_install/build_from_source_en.rst
index 8fef9e7347..d1b5b88dff 100644
--- a/doc/v2/build_and_install/build_from_source_en.rst
+++ b/doc/v2/build_and_install/build_from_source_en.rst
@@ -22,6 +22,8 @@ How To Build
You need to use Docker to build PaddlePaddle
to avoid installing dependencies by yourself. We have several pre-built
Docker images `here `_ ,
+you can also find how to build and use paddle_manylinux_devel Docker image from
+`here `_
Or you can build your own image from source as the optional step below:
.. code-block:: bash
diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index ab71e0e63c..ed1e70c646 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -5,11 +5,11 @@ proto_library(framework_proto SRCS framework.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
-
+cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
if(WITH_GPU)
- nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place memory device_context framework_proto)
+ nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type)
else()
- cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place memory device_context framework_proto)
+ cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type)
endif()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
diff --git a/paddle/fluid/framework/data_type.cc b/paddle/fluid/framework/data_type.cc
new file mode 100644
index 0000000000..b9c90cb0c3
--- /dev/null
+++ b/paddle/fluid/framework/data_type.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/data_type.h"
+#include
+#include
+#include
+
+namespace paddle {
+namespace framework {
+
+struct DataTypeMap {
+ std::unordered_map cpp_to_proto_;
+ std::unordered_map proto_to_cpp_;
+ std::unordered_map proto_to_str_;
+ std::unordered_map cpp_to_size_;
+};
+
+static DataTypeMap* InitDataTypeMap();
+static DataTypeMap& gDataTypeMap() {
+ static DataTypeMap* g_data_type_map_ = InitDataTypeMap();
+ return *g_data_type_map_;
+}
+
+template
+static inline void RegisterType(DataTypeMap* map,
+ proto::VarType::Type proto_type,
+ const std::string& name) {
+ map->proto_to_cpp_.emplace(static_cast(proto_type), typeid(T));
+ map->cpp_to_proto_.emplace(typeid(T), proto_type);
+ map->proto_to_str_.emplace(static_cast(proto_type), name);
+ map->cpp_to_size_.emplace(typeid(T), sizeof(T));
+}
+
+static DataTypeMap* InitDataTypeMap() {
+ auto retv = new DataTypeMap();
+
+#define RegType(cc_type, proto_type) \
+ RegisterType(retv, proto_type, #cc_type)
+
+ // NOTE: Add your customize type here.
+ RegType(platform::float16, proto::VarType::FP16);
+ RegType(float, proto::VarType::FP32);
+ RegType(double, proto::VarType::FP64);
+ RegType(int, proto::VarType::INT32);
+ RegType(int64_t, proto::VarType::INT64);
+ RegType(bool, proto::VarType::BOOL);
+ RegType(size_t, proto::VarType::SIZE_T);
+ RegType(int16_t, proto::VarType::INT16);
+
+#undef RegType
+ return retv;
+}
+
+proto::VarType::Type ToDataType(std::type_index type) {
+ auto it = gDataTypeMap().cpp_to_proto_.find(type);
+ if (it != gDataTypeMap().cpp_to_proto_.end()) {
+ return it->second;
+ }
+ PADDLE_THROW("Not support %s as tensor type", type.name());
+}
+
+std::type_index ToTypeIndex(proto::VarType::Type type) {
+ auto it = gDataTypeMap().proto_to_cpp_.find(static_cast(type));
+ if (it != gDataTypeMap().proto_to_cpp_.end()) {
+ return it->second;
+ }
+ PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
+ static_cast(type));
+}
+
+std::string DataTypeToString(const proto::VarType::Type type) {
+ auto it = gDataTypeMap().proto_to_str_.find(static_cast(type));
+ if (it != gDataTypeMap().proto_to_str_.end()) {
+ return it->second;
+ }
+ PADDLE_THROW("Not support proto::VarType::Type(%d) as tensor type",
+ static_cast(type));
+}
+
+size_t SizeOfType(std::type_index type) {
+ auto it = gDataTypeMap().cpp_to_size_.find(type);
+ if (it != gDataTypeMap().cpp_to_size_.end()) {
+ return it->second;
+ }
+ PADDLE_THROW("Not support %s as tensor type", type.name());
+}
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 2a528eb3aa..4b9f572ec5 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -17,51 +17,14 @@ limitations under the License. */
#include
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
+
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
-inline proto::VarType::Type ToDataType(std::type_index type) {
- if (typeid(platform::float16).hash_code() == type.hash_code()) {
- return proto::VarType::FP16;
- } else if (typeid(const float).hash_code() == type.hash_code()) {
- // CPPLint complains Using C-style cast. Use static_cast() instead
- // One fix to this is to replace float with const float because
- // typeid(T) == typeid(const T)
- // http://en.cppreference.com/w/cpp/language/typeid
- return proto::VarType::FP32;
- } else if (typeid(const double).hash_code() == type.hash_code()) {
- return proto::VarType::FP64;
- } else if (typeid(const int).hash_code() == type.hash_code()) {
- return proto::VarType::INT32;
- } else if (typeid(const int64_t).hash_code() == type.hash_code()) {
- return proto::VarType::INT64;
- } else if (typeid(const bool).hash_code() == type.hash_code()) {
- return proto::VarType::BOOL;
- } else {
- PADDLE_THROW("Not supported");
- }
-}
-
-inline std::type_index ToTypeIndex(proto::VarType::Type type) {
- switch (type) {
- case proto::VarType::FP16:
- return typeid(platform::float16);
- case proto::VarType::FP32:
- return typeid(float);
- case proto::VarType::FP64:
- return typeid(double);
- case proto::VarType::INT32:
- return typeid(int);
- case proto::VarType::INT64:
- return typeid(int64_t);
- case proto::VarType::BOOL:
- return typeid(bool);
- default:
- PADDLE_THROW("Not support type %d", type);
- }
-}
+extern proto::VarType::Type ToDataType(std::type_index type);
+extern std::type_index ToTypeIndex(proto::VarType::Type type);
template
inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
@@ -89,32 +52,12 @@ inline void VisitDataType(proto::VarType::Type type, Visitor visitor) {
}
}
-inline std::string DataTypeToString(const proto::VarType::Type type) {
- switch (type) {
- case proto::VarType::FP16:
- return "float16";
- case proto::VarType::FP32:
- return "float32";
- case proto::VarType::FP64:
- return "float64";
- case proto::VarType::INT16:
- return "int16";
- case proto::VarType::INT32:
- return "int32";
- case proto::VarType::INT64:
- return "int64";
- case proto::VarType::BOOL:
- return "bool";
- default:
- PADDLE_THROW("Not support type %d", type);
- }
-}
-
+extern std::string DataTypeToString(const proto::VarType::Type type);
+extern size_t SizeOfType(std::type_index type);
inline std::ostream& operator<<(std::ostream& out,
const proto::VarType::Type& type) {
out << DataTypeToString(type);
return out;
}
-
} // namespace framework
} // namespace paddle
diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
new file mode 100644
index 0000000000..91bdfe6134
--- /dev/null
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct BuildStrategy {
+ enum class ReduceStrategy { kAllReduce = 0, kReduce = 1 };
+
+ enum class GradientScaleStrategy {
+ kCoeffNumDevice = 0,
+ kOne = 1,
+ kCustomized = 2,
+ };
+
+ ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
+ GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};
+};
+
+} // namespace details
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/fluid/framework/details/execution_strategy.h b/paddle/fluid/framework/details/execution_strategy.h
new file mode 100644
index 0000000000..e8d510ec95
--- /dev/null
+++ b/paddle/fluid/framework/details/execution_strategy.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct ExecutionStrategy {
+ size_t num_threads_{0};
+ bool use_event_{true};
+ bool allow_op_delay_{false};
+};
+
+} // namespace details
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 4755559f8d..45bad58145 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -37,31 +37,26 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::string &loss_var_name,
const std::unordered_set ¶ms,
const std::vector &local_scopes,
- platform::NCCLContextMap *nccl_ctxs, bool use_default_grad_scale,
- bool balance_parameter_opt_between_cards)
+ platform::NCCLContextMap *nccl_ctxs, const BuildStrategy &strategy)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs),
- balance_parameter_opt_between_cards_(
- balance_parameter_opt_between_cards) {
+ strategy_(strategy) {
#else
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
const std::vector &places,
const std::string &loss_var_name,
const std::unordered_set ¶ms,
- const std::vector &local_scopes, bool use_default_grad_scale,
- bool balance_parameter_opt_between_cards)
+ const std::vector &local_scopes, const BuildStrategy &strategy)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
- balance_parameter_opt_between_cards_(
- balance_parameter_opt_between_cards) {
+ strategy_(strategy) {
#endif
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
- use_default_grad_scale_ = use_default_grad_scale;
}
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
@@ -146,7 +141,8 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_
- if (use_default_grad_scale_) {
+ if (strategy_.gradient_scale_ !=
+ BuildStrategy::GradientScaleStrategy::kCustomized) {
CreateScaleLossGradOp(&result);
}
is_forwarding = false;
@@ -165,19 +161,22 @@ std::unique_ptr MultiDevSSAGraphBuilder::Build(
// broadcast, and each gradient is only broadcast once.
for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
- if (balance_parameter_opt_between_cards_) {
- CreateReduceOp(&result, og, cur_device_id);
- var_name_on_devices[cur_device_id].emplace(og);
- bcast_var_name_set[cur_device_id].emplace(
- og.substr(0, og.size() - strlen(kGradVarSuffix)));
- cur_device_id = (cur_device_id + 1) % places_.size();
- } else {
- if (IsSparseGradient(var_types, og)) {
- CreateReduceOp(&result, og, 0);
- CreateBroadcastOp(&result, og, 0);
- } else {
- InsertNCCLAllReduceOp(&result, og);
- }
+ switch (strategy_.reduce_) {
+ case BuildStrategy::ReduceStrategy::kReduce:
+ CreateReduceOp(&result, og, cur_device_id);
+ var_name_on_devices[cur_device_id].emplace(og);
+ bcast_var_name_set[cur_device_id].emplace(
+ og.substr(0, og.size() - strlen(kGradVarSuffix)));
+ cur_device_id = (cur_device_id + 1) % places_.size();
+ break;
+ case BuildStrategy::ReduceStrategy::kAllReduce:
+ if (IsSparseGradient(var_types, og)) {
+ CreateReduceOp(&result, og, 0);
+ CreateBroadcastOp(&result, og, 0);
+ } else {
+ InsertNCCLAllReduceOp(&result, og);
+ }
+ break;
}
}
}
@@ -303,7 +302,7 @@ bool MultiDevSSAGraphBuilder::IsParameterGradientOnce(
int MultiDevSSAGraphBuilder::GetOpDeviceID(
const std::vector> &var_name_on_devices,
const OpDesc &op) const {
- if (!balance_parameter_opt_between_cards_) {
+ if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1;
}
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 3a3e9e3b85..4f70852188 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -17,6 +17,7 @@
#include
#include
+#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
namespace paddle {
@@ -36,15 +37,13 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
const std::unordered_set ¶ms,
const std::vector &local_scopes,
platform::NCCLContextMap *nccl_ctxs,
- bool use_default_grad_scale,
- bool balance_parameter_opt_between_cards);
+ const BuildStrategy &strategy);
#else
MultiDevSSAGraphBuilder(const std::vector &places,
const std::string &loss_var_name,
const std::unordered_set ¶ms,
const std::vector &local_scopes,
- bool use_default_grad_scale,
- bool balance_parameter_opt_between_cards);
+ const BuildStrategy &strategy);
#endif
std::unique_ptr Build(const ProgramDesc &program) const override;
@@ -62,8 +61,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap *nccl_ctxs_;
#endif
- bool balance_parameter_opt_between_cards_;
- bool use_default_grad_scale_;
bool IsScaleLossOp(const OpDesc &op) const;
@@ -105,6 +102,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsSparseGradient(
const std::unordered_map &var_types,
const std::string &og) const;
+
+ private:
+ BuildStrategy strategy_;
};
} // namespace details
} // namespace framework
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
index e90523ebe8..ef263d82c5 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -18,18 +18,17 @@ namespace paddle {
namespace framework {
namespace details {
ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
- size_t num_threads, bool use_event,
- const std::vector &local_scopes,
+ const ExecutionStrategy &strategy, const std::vector &local_scopes,
const std::vector &places,
- std::unique_ptr &&graph, bool allow_op_delay)
+ std::unique_ptr &&graph)
: SSAGraphExecutor(std::move(graph)),
- pool_(num_threads >= 2 ? new ::ThreadPool(num_threads) : nullptr),
+ pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
+ : nullptr),
local_scopes_(local_scopes),
places_(places),
fetch_ctxs_(places),
- use_event_(use_event),
running_ops_(0),
- allow_op_delay_(allow_op_delay) {}
+ strategy_(strategy) {}
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector &fetch_tensors) {
@@ -86,7 +85,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
//
// NOTE: DelayedOps have a lower priority. It will be scheduled after all
// ready_ops have been performed.
- if (ready_ops.empty() && allow_op_delay_ && running_ops_ == 0) {
+ if (ready_ops.empty() && strategy_.allow_op_delay_ && running_ops_ == 0) {
run_all_ops(delayed_ops);
} else {
run_all_ops(ready_ops);
@@ -113,7 +112,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto &deps = pending_ops[op];
--deps;
if (deps == 0) {
- if (op->IsMultiDeviceTransfer() && allow_op_delay_) {
+ if (op->IsMultiDeviceTransfer() && strategy_.allow_op_delay_) {
delayed_ops.insert(op);
} else {
ready_ops.insert(op);
@@ -191,7 +190,7 @@ void ThreadedSSAGraphExecutor::RunOp(
auto op_run = [ready_var_q, op, this] {
try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
- op->Run(use_event_);
+ op->Run(strategy_.use_event_);
VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--;
ready_var_q->Extend(op->Outputs());
diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
index f18a88526b..1f7f88d752 100644
--- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h
@@ -23,6 +23,7 @@
#include
#include "ThreadPool.h" // ThreadPool in thrird party
#include "paddle/fluid/framework/blocking_queue.h"
+#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
@@ -34,11 +35,10 @@ namespace details {
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
- ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
+ ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector &local_scopes,
const std::vector &places,
- std::unique_ptr &&graph,
- bool allow_op_delay);
+ std::unique_ptr &&graph);
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
@@ -55,10 +55,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::vector local_scopes_;
std::vector places_;
platform::DeviceContextPool fetch_ctxs_;
- const bool use_event_;
std::unique_ptr exception_;
std::atomic running_ops_;
- bool allow_op_delay_;
void InsertPendingOp(std::unordered_map *pending_ops,
OpHandleBase *op_instance) const;
@@ -74,6 +72,9 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::unordered_map *pending_ops,
std::unordered_set *pending_vars,
BlockingQueue *ready_vars, FeedFetchList *fetch_data);
+
+ private:
+ ExecutionStrategy strategy_;
};
} // namespace details
diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto
index 96f53dc1bc..d2558f111f 100644
--- a/paddle/fluid/framework/framework.proto
+++ b/paddle/fluid/framework/framework.proto
@@ -101,6 +101,8 @@ message VarType {
FP16 = 4;
FP32 = 5;
FP64 = 6;
+  // Tensor<size_t> is used in C++.
+ SIZE_T = 19;
// Other types that may need additional descriptions
LOD_TENSOR = 7;
diff --git a/paddle/fluid/framework/op_kernel_type_test.cc b/paddle/fluid/framework/op_kernel_type_test.cc
index d37ce149ce..db95861c51 100644
--- a/paddle/fluid/framework/op_kernel_type_test.cc
+++ b/paddle/fluid/framework/op_kernel_type_test.cc
@@ -27,7 +27,7 @@ TEST(OpKernelType, ToString) {
LibraryType::kCUDNN);
ASSERT_EQ(paddle::framework::KernelTypeToString(op_kernel_type),
- "data_type[float32]:data_layout[NCHW]:place[CPUPlace]:library_type["
+ "data_type[float]:data_layout[NCHW]:place[CPUPlace]:library_type["
"CUDNN]");
}
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index d373c48b1a..a4eb6f706e 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -192,6 +192,10 @@ class ExecutionContext {
return op_.Attr(name);
}
+ bool HasInput(const std::string& name) const { return op_.HasInputs(name); }
+
+ bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); }
+
size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();
}
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 20ef7e09f6..50c3468d55 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -52,13 +52,13 @@ std::vector &ParallelExecutor::GetLocalScopes() {
}
ParallelExecutor::ParallelExecutor(
- size_t num_threads, bool use_event,
const std::vector &places,
const std::unordered_set ¶ms,
const std::unordered_set &bcast_vars,
const ProgramDesc &main_program, const std::string &loss_var_name,
- Scope *scope, const std::vector &local_scopes, bool allow_op_delay,
- bool use_default_grad_scale, bool balance_parameter_opt_between_cards)
+ Scope *scope, const std::vector &local_scopes,
+ const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
+ size_t num_trainers, size_t trainer_id)
: member_(new ParallelExecutorPrivate(places)) {
member_->global_scope_ = scope;
@@ -80,7 +80,13 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs
#ifdef PADDLE_WITH_CUDA
- member_->nccl_ctxs_.reset(new platform::NCCLContextMap(member_->places_));
+ auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
+ ncclUniqueId *nccl_id = nullptr;
+ if (nccl_id_var != nullptr) {
+ nccl_id = nccl_id_var->GetMutable();
+ }
+ member_->nccl_ctxs_.reset(new platform::NCCLContextMap(
+ member_->places_, nccl_id, num_trainers, trainer_id));
#endif
if (platform::is_gpu_place(places[0]) && member_->local_scopes_.size() != 1 &&
local_scopes.empty()) { // Is CUDA
@@ -93,18 +99,16 @@ ParallelExecutor::ParallelExecutor(
#ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder(
member_->places_, loss_var_name, params, member_->local_scopes_,
- member_->nccl_ctxs_.get(), use_default_grad_scale,
- balance_parameter_opt_between_cards);
+ member_->nccl_ctxs_.get(), build_strategy);
#else
- details::MultiDevSSAGraphBuilder builder(
- member_->places_, loss_var_name, params, member_->local_scopes_,
- use_default_grad_scale, balance_parameter_opt_between_cards);
+ details::MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name,
+ params, member_->local_scopes_,
+ build_strategy);
#endif
auto graph = builder.Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
- num_threads, use_event, member_->local_scopes_, places, std::move(graph),
- allow_op_delay));
+ exec_strategy, member_->local_scopes_, places, std::move(graph)));
// Step 3. Create vars in each scope;
for (auto *var : main_program.Block(0).AllVars()) {
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index b251fc9141..5247e79064 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -14,56 +14,60 @@ limitations under the License. */
#pragma once
+#include
#include
#include
#include
+#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
-
namespace paddle {
namespace framework {
class ParallelExecutorPrivate;
+using details::BuildStrategy;
+using details::ExecutionStrategy;
+
class ParallelExecutor {
DISABLE_COPY_AND_ASSIGN(ParallelExecutor);
public:
- explicit ParallelExecutor(size_t num_threads, bool use_event,
- const std::vector& places,
- const std::unordered_set& params,
- const std::unordered_set& bcast_vars,
- const ProgramDesc& main_program,
- const std::string& loss_var_name, Scope* scope,
- const std::vector& local_scopes,
- bool allow_op_delay, bool use_default_grad_scale,
- bool balance_parameter_opt_between_cards);
+ explicit ParallelExecutor(const std::vector &places,
+ const std::unordered_set ¶ms,
+ const std::unordered_set &bcast_vars,
+ const ProgramDesc &main_program,
+ const std::string &loss_var_name, Scope *scope,
+ const std::vector &local_scopes,
+ const ExecutionStrategy &exec_strategy,
+ const BuildStrategy &build_strategy,
+ size_t num_trainers = 1, size_t trainer_id = 0);
~ParallelExecutor();
- std::vector& GetLocalScopes();
+ std::vector &GetLocalScopes();
/**
* Feed tensors to local scopes. The size of tensors should be equal to the
* size of local scopes.
*/
void FeedTensorsIntoLocalScopes(
- const std::vector>& tensors);
+ const std::vector> &tensors);
void FeedAndSplitTensorIntoLocalScopes(
- const std::unordered_map& tensors);
+ const std::unordered_map &tensors);
- void Run(const std::vector& fetch_tensors,
- const std::string& fetched_var_name);
+ void Run(const std::vector &fetch_tensors,
+ const std::string &fetched_var_name);
- void BCastParamsToGPUs(const std::unordered_set& vars) const;
+ void BCastParamsToGPUs(const std::unordered_set &vars) const;
private:
- ParallelExecutorPrivate* member_;
+ ParallelExecutorPrivate *member_;
};
} // namespace framework
diff --git a/paddle/fluid/framework/tensor_impl.h b/paddle/fluid/framework/tensor_impl.h
index f49d1a47a3..0a1db7758b 100644
--- a/paddle/fluid/framework/tensor_impl.h
+++ b/paddle/fluid/framework/tensor_impl.h
@@ -13,54 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
-
-template
-struct SizeOfTypeFunctor;
-
-template
-struct SizeOfTypeFunctor {
- size_t operator()(std::type_index type) const {
- if (typeid(T).hash_code() == type.hash_code()) {
- return sizeof(T);
- } else {
- return 0UL;
- }
- }
-};
-
-template <>
-struct SizeOfTypeFunctor<> {
- size_t operator()(std::type_index type) const { return 0UL; }
-};
-
-template
-struct SizeOfTypeFunctor {
- size_t operator()(std::type_index type) const {
- SizeOfTypeFunctor head;
- size_t head_size = head(type);
- if (head_size != 0) {
- return head_size;
- }
- SizeOfTypeFunctor tail;
- return tail(type);
- }
-};
-
-static inline size_t SizeOfType(std::type_index type) {
- SizeOfTypeFunctor
- functor;
- size_t size = functor(type);
- PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name());
- return size;
-}
-
+extern size_t SizeOfType(std::type_index type);
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tensor holds no memory. Call Tensor::mutable_data first.");
diff --git a/paddle/fluid/inference/analysis/CMakeLists.txt b/paddle/fluid/inference/analysis/CMakeLists.txt
index de7becae4d..47929ef749 100644
--- a/paddle/fluid/inference/analysis/CMakeLists.txt
+++ b/paddle/fluid/inference/analysis/CMakeLists.txt
@@ -1 +1,2 @@
-cc_library(dot SRCS dot.cc)
+cc_library(analysis SRCS dot.cc node.cc node.h)
+cc_test(test_node SRCS node_tester.cc DEPS analysis)
diff --git a/paddle/fluid/inference/analysis/device.h b/paddle/fluid/inference/analysis/device.h
new file mode 100644
index 0000000000..585c992329
--- /dev/null
+++ b/paddle/fluid/inference/analysis/device.h
@@ -0,0 +1,24 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+enum class Device { CPU, GPU };
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/dot.h b/paddle/fluid/inference/analysis/dot.h
index 3359987874..4bf1840fdd 100644
--- a/paddle/fluid/inference/analysis/dot.h
+++ b/paddle/fluid/inference/analysis/dot.h
@@ -21,6 +21,7 @@
#include
#include
+#include
#include
#include
diff --git a/paddle/fluid/inference/analysis/dot_tester.cc b/paddle/fluid/inference/analysis/dot_tester.cc
new file mode 100644
index 0000000000..56ceb9bd5d
--- /dev/null
+++ b/paddle/fluid/inference/analysis/dot_tester.cc
@@ -0,0 +1,62 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/inference/analysis/dot.h"
+
+#include
+#include
+#include "paddle/fluid/inference/analysis/data_flow_graph.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+class DotTester : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ std::vector attrs({{"title", "hello"}});
+ dot.reset(new Dot(attrs));
+ dot->AddNode("a", {Dot::Attr{"shape", "box"}, Dot::Attr("color", "blue")});
+ dot->AddNode("b", {});
+ dot->AddNode("c", {});
+ dot->AddEdge("a", "b", {});
+ dot->AddEdge("b", "c", {});
+ dot->AddEdge("a", "c", {});
+ }
+
+ std::unique_ptr dot;
+};
+
+TEST_F(DotTester, Build) {
+ auto codes = dot->Build();
+ // Output the DOT language code, the generated codes are too long to compare
+ // the string.
+ //
+ // The output is
+ //
+ // digraph G {
+ // title="hello"
+ // node_1
+ // node_2
+ // node_0[label="a" shape="box" color="blue"]
+ // node_0->node_1
+ // node_1->node_2
+ // node_0->node_2
+ // } // end G
+ LOG(INFO) << '\n' << codes;
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h
new file mode 100644
index 0000000000..b2d06c5d63
--- /dev/null
+++ b/paddle/fluid/inference/analysis/helper.h
@@ -0,0 +1,74 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+template
+class iterator_range {
+ IteratorT begin_, end_;
+
+ public:
+ template
+ explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
+
+ iterator_range(const IteratorT &begin, const IteratorT &end)
+ : begin_(begin), end_(end) {}
+
+ const IteratorT &begin() const { return begin_; }
+ const IteratorT &end() const { return end_; }
+};
+
+/*
+ * A registry helper class; its records keep the order in which they register.
+ */
+template
+class OrderedRegistry {
+ public:
+ T *Register(const std::string &name, T *x) {
+ PADDLE_ENFORCE(!dic_.count(name));
+ dic_[name] = data_.size();
+ data_.emplace_back(std::unique_ptr(x));
+ return data_.back().get();
+ }
+
+ T *Lookup(const std::string &name) {
+ auto it = dic_.find(name);
+ if (it == dic_.end()) return nullptr;
+ return data_[it->second].get();
+ }
+
+ protected:
+ std::unordered_map dic_;
+ std::vector> data_;
+};
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
+
+#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
+ \
+ type__(const type__ &) = delete; \
+ \
+ void operator=(const type__ &) = delete;
diff --git a/paddle/fluid/inference/analysis/node.cc b/paddle/fluid/inference/analysis/node.cc
new file mode 100644
index 0000000000..fe06052608
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/analysis/node.h"
+#include "glog/logging.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+std::vector Value::dot_attrs() const {
+ return std::vector({Dot::Attr("style", "filled,rounded"),
+ Dot::Attr("shape", "box"),
+ Dot::Attr("fillcolor", "red")});
+}
+
+std::vector Function::dot_attrs() const {
+ return std::vector({Dot::Attr("style", "filled,rounded"),
+ Dot::Attr("shape", "diamond"),
+ Dot::Attr("fillcolor", "yellow")});
+}
+
+Node *NodeMap::Create(Node::Type type) {
+ switch (type) {
+ case Node::Type::kFunction:
+ nodes_.emplace_back(new Function);
+ break;
+ case Node::Type::kValue:
+ nodes_.emplace_back(new Value);
+ break;
+ default:
+ PADDLE_THROW("Not supported node type.");
+ }
+ nodes_.back()->id_ = size() - 1;
+ return nodes_.back().get();
+}
+
+Node *NodeMap::GetMutable(size_t id) {
+ PADDLE_ENFORCE_GT(size(), id);
+ return nodes_[id].get();
+}
+
+const Node &NodeMap::Get(size_t id) const {
+ PADDLE_ENFORCE_GT(size(), id);
+ return *nodes_[id].get();
+}
+
+void NodeMap::Delete(size_t id) {
+ PADDLE_ENFORCE_LT(id, size());
+ nodes_[id]->SetDeleted();
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/node.h b/paddle/fluid/inference/analysis/node.h
new file mode 100644
index 0000000000..07cb7669f9
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node.h
@@ -0,0 +1,235 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+/*
+ * This file defines the Node class and its subclasses. A Node is the basic
+ * analysis element in a computation graph.
+ * There are basically two kinds of nodes, the function node and value node.
+ */
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/inference/analysis/device.h"
+#include "paddle/fluid/inference/analysis/dot.h"
+#include "paddle/fluid/inference/analysis/helper.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+class NodeMap;
+
+/*
+ * Node Representation.
+ *
+ * This is a very important class for analysis. It is the base class of all
+ * nodes computed by a program that may be used as operands to other nodes.
+ * Node is the super class of other important classes such as Function and
+ * Value, some nodes can have a name.
+ */
+class Node {
+ public:
+ // Node type. NOTE the new node types should add here.
+ enum class Type { kNone = -1, kFunction, kValue, kFunctionBlock };
+
+ Node() = default;
+
+ struct Attr;
+
+ // Cast to a subclass type, Function for example.
+ template
+ Subclass &As() {
+ return *dynamic_cast(this);
+ }
+
+ // Formatted representation of this Node.
+ virtual std::string repr() const {
+ return name() + "(" + std::to_string(id()) + ")";
+ }
+
+ // DOT node representation. One Node type can customize its own node
+ // representation.
+ virtual std::vector dot_attrs() const {
+ return std::vector({Dot::Attr("style", "filled")});
+ }
+
+ // Get an additional attribute and convert it to T data type. NOTE this will
+ // silently create a new attribute if not exists.
+ Attr &attr(const std::string &name) { return attrs_[name]; }
+
+ int id() const { return id_; }
+
+ bool deleted() const { return deleted_; }
+ void SetDeleted() { deleted_ = true; }
+
+ void SetName(const std::string &name) { name_ = name; }
+ const std::string &name() const { return name_; }
+
+ void SetType(Type type) { type_ = type; }
+ Type type() const { return type_; }
+
+ void *extra_info() const { return extra_info_; }
+ void SetExtraInfo(void *extra_info) { extra_info_ = extra_info; }
+
+ // Input links.
+ std::vector inlinks;
+ // Output links.
+ std::vector outlinks;
+
+ // A helper class to maintain the status from Pass.
+ // TODO(superjomn) add a checker here to ensure the T is primary.
+ struct Attr {
+ // NOTE T should be a primary type or a struct combined by several primary
+ // types.
+ // NOTE the STL containers should not use here.
+ // Some usages
+ // Attr attr;
+ // T data;
+ // attr.data.assign((char*)data, sizeof(data));
+
+ bool &Bool() { return As(); }
+ float &Float() { return As(); }
+ int32_t &Int32() { return As(); }
+ int64_t &Int64() { return As(); }
+
+ private:
+ template
+ T &As() {
+ // init storage in the first usage.
+ if (data_.empty()) {
+ VLOG(4) << "resize data to " << sizeof(T);
+ type_hash_ = typeid(T).hash_code();
+ data_.resize(sizeof(T));
+ }
+ PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(), "type not matched");
+ PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
+ return *reinterpret_cast(&data_[0]);
+ }
+
+ private:
+ std::string data_;
+ size_t type_hash_{std::numeric_limits::max()};
+ };
+
+ virtual ~Node() {}
+
+ friend class NodeMap;
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Node);
+
+ protected:
+ // The id number, not the name, is a node's unique identifier in the computation
+ // graph.
+ int id_{-1};
+ std::string name_;
+ Type type_{Type::kNone};
+ // Mark this node is deleted by some pass.
+ bool deleted_{false};
+
+ void *extra_info_;
+
+ mutable std::unordered_map attrs_;
+};
+
+class Function;
+/*
+ * Value represents a value node, it has some attributes including dims, data
+ * type and so on.
+ */
+class Value : public Node {
+ public:
+ enum class DataType { kInt32, kInt64, kFloat32, kFloat64 };
+ using Dims = std::vector;
+
+ void SetDataType(DataType data_type) { data_type_ = data_type; }
+ DataType data_type() const { return data_type_; }
+
+ void SetDims(const Dims &dims) { dims_ = dims; }
+ const Dims &dims() const { return dims_; }
+
+ Device device() const { return device_; }
+ void SetDevice(Device device) { device_ = device; }
+
+ std::vector dot_attrs() const override;
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Value);
+
+ protected:
+ Value() { SetType(Node::Type::kValue); }
+ friend class NodeMap;
+
+ private:
+ DataType data_type_;
+ Dims dims_;
+ Device device_;
+};
+
+/*
+ * Function represents any kind of executable concept that takes several Values
+ * as input, and outputs several Values.
+ */
+class Function : public Node {
+ public:
+ std::vector dot_attrs() const override;
+
+ // Get the operator's type from Desc.
+ const std::string &func_type() const { return func_type_; }
+ // Set the operator's type.
+ void SetFuncType(const std::string &func_type) { func_type_ = func_type; }
+
+ PADDLE_DISALLOW_COPY_AND_ASSIGN(Function);
+
+ protected:
+ std::string func_type_;
+ Function() { SetType(Node::Type::kFunction); }
+ friend class NodeMap;
+};
+
+/*
+ * FunctionBlock is a Node that contains a sub-graph of multiple Nodes.
+ */
+struct FunctionBlock : public Node {
+ std::string repr() const override { return "block-" + std::to_string(id()); }
+ std::vector subgraph;
+};
+
+class NodeMap {
+ public:
+ // Create a new node with type.
+ Node *Create(Node::Type type);
+
+ // Get a node by its id.
+ Node *GetMutable(size_t id);
+
+ const Node &Get(size_t id) const;
+
+ void Delete(size_t id);
+
+ const std::vector> &nodes() { return nodes_; }
+
+ size_t size() const { return nodes_.size(); }
+
+ private:
+ std::vector> nodes_;
+ std::unordered_map map_;
+};
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/analysis/node_tester.cc b/paddle/fluid/inference/analysis/node_tester.cc
new file mode 100644
index 0000000000..47fea0fdff
--- /dev/null
+++ b/paddle/fluid/inference/analysis/node_tester.cc
@@ -0,0 +1,34 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/fluid/inference/analysis/node.h"
+
+#include
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+TEST(Node, Attr) {
+ // Node is an abstract class, use Value instead for they share the same Attr
+ // logic.
+ NodeMap nodes;
+ auto* node = nodes.Create(Node::Type::kValue);
+ node->attr("v0").Int32() = 2008;
+ ASSERT_EQ(node->attr("v0").Int32(), 2008);
+}
+
+} // namespace analysis
+} // namespace inference
+} // namespace paddle
diff --git a/paddle/fluid/inference/engine.h b/paddle/fluid/inference/engine.h
index de0375551e..ce2b816171 100644
--- a/paddle/fluid/inference/engine.h
+++ b/paddle/fluid/inference/engine.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include "paddle/fluid/framework/framework.pb.h"
namespace paddle {
@@ -58,8 +59,8 @@ class EngineBase {
struct Buffer {
void* buffer{nullptr}; // buffer should be allocated only once.
- int max_size; // buffer allocated space.
- int size; // data size.
+ size_t max_size; // buffer allocated space.
+ size_t size; // data size.
DeviceType device{DeviceType::UNK}; // tells which device this buffer is on.
};
diff --git a/paddle/fluid/inference/tensorrt/CMakeLists.txt b/paddle/fluid/inference/tensorrt/CMakeLists.txt
index 677b3e04af..b52d083f28 100644
--- a/paddle/fluid/inference/tensorrt/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/CMakeLists.txt
@@ -1,5 +1,4 @@
nv_library(tensorrt_engine SRCS engine.cc DEPS framework_proto)
nv_test(test_tensorrt SRCS test_tensorrt.cc DEPS dynload_cuda device_context dynamic_loader)
nv_test(test_tensorrt_engine SRCS test_engine.cc DEPS dynload_cuda tensorrt_engine)
-
add_subdirectory(convert)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 5178c54c08..4fb4511d99 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -1,4 +1,4 @@
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES})
-nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc
+nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
index 543784289c..6297051e5a 100644
--- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc
@@ -21,15 +21,18 @@ namespace tensorrt {
class ReluOpConverter : public OpConverter {
public:
ReluOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
+ // Here the two nullptrs look strange; that's because the
+ // framework::OpDesc's constructor is strange.
+ framework::OpDesc op_desc(op, nullptr, nullptr);
LOG(INFO) << "convert a fluid relu op to tensorrt activation layer whose "
"type is Relu";
const nvinfer1::ITensor* input_tensor =
- engine_->GetITensor(op.Input("X")[0]);
+ engine_->GetITensor(op_desc.Input("X")[0]);
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast(input_tensor),
nvinfer1::ActivationType::kRELU);
- engine_->SetITensor(op.Output("Out")[0], layer->getOutput(0));
+ engine_->SetITensor(op_desc.Output("Out")[0], layer->getOutput(0));
}
};
diff --git a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
index 431500b90e..209936c3ba 100644
--- a/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
@@ -21,7 +21,7 @@ namespace tensorrt {
class Conv2dOpConverter : public OpConverter {
public:
Conv2dOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO)
<< "convert a fluid conv2d op to tensorrt conv layer without bias";
}
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.cc b/paddle/fluid/inference/tensorrt/convert/io_converter.cc
index 32e8631fde..854f434d93 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.cc
@@ -23,26 +23,42 @@ namespace tensorrt {
using platform::is_gpu_place;
using platform::is_cpu_place;
-class DefaultInputConverter : public EngineInputConverter {
+class DefaultIOConverter : public EngineIOConverter {
public:
- DefaultInputConverter() {}
+ DefaultIOConverter() {}
// NOTE out is GPU memory.
virtual void operator()(const LoDTensor& in, void* out,
size_t max_size) override {
PADDLE_ENFORCE(out != nullptr);
- PADDLE_ENFORCE_LE(in.memory_size(), max_size);
+ PADDLE_ENFORCE(stream_ != nullptr);
const auto& place = in.place();
+ size_t size = in.memory_size();
+ PADDLE_ENFORCE_LE(size, max_size);
if (is_cpu_place(place)) {
- PADDLE_ENFORCE(stream_ != nullptr);
- PADDLE_ENFORCE_EQ(0,
- cudaMemcpyAsync(out, in.data(), in.memory_size(),
- cudaMemcpyHostToDevice, *stream_));
-
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size,
+ cudaMemcpyHostToDevice, *stream_));
} else if (is_gpu_place(place)) {
- PADDLE_ENFORCE_EQ(0,
- cudaMemcpyAsync(out, in.data(), in.memory_size(),
- cudaMemcpyHostToHost, *stream_));
-
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out, in.data(), size,
+ cudaMemcpyDeviceToDevice, *stream_));
+ } else {
+ PADDLE_THROW("Unknown device for converter");
+ }
+ cudaStreamSynchronize(*stream_);
+ }
+ // NOTE in is GPU memory.
+ virtual void operator()(const void* in, LoDTensor* out,
+ size_t max_size) override {
+ PADDLE_ENFORCE(in != nullptr);
+ PADDLE_ENFORCE(stream_ != nullptr);
+ const auto& place = out->place();
+ size_t size = out->memory_size();
+ PADDLE_ENFORCE_LE(size, max_size);
+ if (is_cpu_place(place)) {
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size,
+ cudaMemcpyDeviceToHost, *stream_));
+ } else if (is_gpu_place(place)) {
+ PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(out->data(), in, size,
+ cudaMemcpyDeviceToDevice, *stream_));
} else {
PADDLE_THROW("Unknown device for converter");
}
@@ -50,7 +66,8 @@ class DefaultInputConverter : public EngineInputConverter {
}
};
-REGISTER_TENSORRT_INPUT_CONVERTER(default, DefaultInputConverter);
+// fluid LodTensor <-> tensorrt ITensor
+REGISTER_TENSORRT_IO_CONVERTER(default, DefaultIOConverter);
} // namespace tensorrt
} // namespace inference
diff --git a/paddle/fluid/inference/tensorrt/convert/io_converter.h b/paddle/fluid/inference/tensorrt/convert/io_converter.h
index 8972dae92b..71c48e085d 100644
--- a/paddle/fluid/inference/tensorrt/convert/io_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/io_converter.h
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
+#include
#include
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/utils/singleton.h"
@@ -25,43 +26,57 @@ namespace tensorrt {
using framework::LoDTensor;
/*
- * Convert Input from Fluid to an Engine.
- * TensorRT's ITensor follows row major, NCHW. Fluid is also row major, so in
- * most cases just need to copy the data.
+ * Convert Input from Fluid to TensorRT Engine.
+ * Convert Output from TensorRT Engine to Fluid.
+ *
+ * Note that TensorRT's ITensor follows a row-major NCHW layout. Fluid is
+ * also row-major,
+ * so in the default case we just need to copy the data.
*/
-class EngineInputConverter {
+class EngineIOConverter {
public:
- EngineInputConverter() {}
+ EngineIOConverter() {}
virtual void operator()(const LoDTensor& in, void* out, size_t max_size) {}
+ virtual void operator()(const void* in, LoDTensor* out, size_t max_size) {}
void SetStream(cudaStream_t* stream) { stream_ = stream; }
- static void Run(const std::string& in_op_type, const LoDTensor& in, void* out,
- size_t max_size, cudaStream_t* stream) {
+ static void ConvertInput(const std::string& op_type, const LoDTensor& in,
+ void* out, size_t max_size, cudaStream_t* stream) {
PADDLE_ENFORCE(stream != nullptr);
- auto* converter = Registry::Lookup(
- in_op_type, "default" /* default_type */);
+ auto* converter = Registry::Lookup(
+ op_type, "default" /* default_type */);
PADDLE_ENFORCE_NOT_NULL(converter);
converter->SetStream(stream);
(*converter)(in, out, max_size);
}
- virtual ~EngineInputConverter() {}
+ static void ConvertOutput(const std::string& op_type, const void* in,
+ LoDTensor* out, size_t max_size,
+ cudaStream_t* stream) {
+ PADDLE_ENFORCE(stream != nullptr);
+ auto* converter = Registry::Lookup(
+ op_type, "default" /* default_type */);
+ PADDLE_ENFORCE_NOT_NULL(converter);
+ converter->SetStream(stream);
+ (*converter)(in, out, max_size);
+ }
+
+ virtual ~EngineIOConverter() {}
protected:
cudaStream_t* stream_{nullptr};
};
+#define REGISTER_TENSORRT_IO_CONVERTER(op_type__, Converter__) \
+ struct trt_io_##op_type__##_converter { \
+ trt_io_##op_type__##_converter() { \
+ Registry::Register(#op_type__); \
+ } \
+ }; \
+ trt_io_##op_type__##_converter trt_io_##op_type__##_converter__;
+
} // namespace tensorrt
} // namespace inference
} // namespace paddle
-
-#define REGISTER_TENSORRT_INPUT_CONVERTER(in_op_type__, Converter__) \
- struct trt_input_##in_op_type__##_converter { \
- trt_input_##in_op_type__##_converter() { \
- ::paddle::inference::Registry::Register< \
- Converter__>(#in_op_type__); \
- } \
- }; \
- trt_input_##in_op_type__##_converter trt_input_##in_op_type__##_converter__;
diff --git a/paddle/fluid/inference/tensorrt/convert/mul_op.cc b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
index f9834ab156..3ca58b139b 100644
--- a/paddle/fluid/inference/tensorrt/convert/mul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/mul_op.cc
@@ -21,7 +21,7 @@ namespace tensorrt {
class MulOpConverter : public OpConverter {
public:
MulOpConverter() {}
- void operator()(const framework::OpDesc& op) override {
+ void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias";
}
};
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index 77c788550b..abc9ebf472 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -31,10 +31,10 @@ namespace tensorrt {
class OpConverter {
public:
OpConverter() {}
- virtual void operator()(const framework::OpDesc& op) {}
+ virtual void operator()(const framework::proto::OpDesc& op) {}
- void Run(const framework::OpDesc& op, TensorRTEngine* engine) {
- std::string type = op.Type();
+ void Run(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
+ std::string type = op.type();
auto* it = Registry::Lookup(type);
PADDLE_ENFORCE_NOT_NULL(it, "no OpConverter for optype [%s]", type);
it->SetEngine(engine);
@@ -42,14 +42,16 @@ class OpConverter {
}
// convert fluid op to tensorrt layer
- void ConvertOp(const framework::OpDesc& op, TensorRTEngine* engine) {
+ void ConvertOp(const framework::proto::OpDesc& op, TensorRTEngine* engine) {
OpConverter::Run(op, engine);
}
// convert fluid block to tensorrt network
- void ConvertBlock(const framework::BlockDesc& block, TensorRTEngine* engine) {
- for (auto op : block.AllOps()) {
- OpConverter::Run(*op, engine);
+ void ConvertBlock(const framework::proto::BlockDesc& block,
+ TensorRTEngine* engine) {
+ for (size_t i = 0; i < block.ops_size(); i++) {
+ const auto& op = block.ops(i);
+ OpConverter::Run(op, engine);
}
}
diff --git a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
index 23e3435c21..ec33f97c82 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_activation_op.cc
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
@@ -26,7 +27,7 @@ namespace paddle {
namespace inference {
namespace tensorrt {
-void Compare(float input, float expect) {
+void Compare(const std::string op_type, float input, float expect) {
framework::Scope scope;
platform::CUDAPlace place;
platform::CUDADeviceContext ctx(place);
@@ -35,6 +36,7 @@ void Compare(float input, float expect) {
auto x_var = scope.Var("X");
auto x_tensor = x_var->GetMutable();
x_tensor->Resize({1, 1});
+ x_tensor->mutable_data(place);
std::vector init;
init.push_back(input);
framework::TensorFromVector(init, ctx, x_tensor);
@@ -45,14 +47,15 @@ void Compare(float input, float expect) {
out_tensor->mutable_data(place);
framework::OpDesc op_desc;
- op_desc.SetType("relu");
+ op_desc.SetType(op_type);
op_desc.SetInput("X", {"X"});
op_desc.SetOutput("Out", {"Out"});
- auto relu_op = framework::OpRegistry::CreateOp(op_desc);
+ auto op = framework::OpRegistry::CreateOp(*op_desc.Proto());
// run fluid op
- relu_op->Run(scope, place);
+ op->Run(scope, place);
+ // get fluid output
std::vector out1;
framework::TensorToVector(*out_tensor, ctx, &out1);
@@ -63,21 +66,28 @@ void Compare(float input, float expect) {
engine->InitNetwork();
engine->DeclareInput("X", nvinfer1::DataType::kFLOAT,
nvinfer1::DimsCHW{1, 1, 1});
-
+ // convert op
OpConverter op_converter;
- op_converter.ConvertOp(op_desc, engine);
+ op_converter.ConvertOp(*op_desc.Proto(), engine);
engine->DeclareOutput("Out");
engine->FreezeNetwork();
- engine->SetInputFromCPU("X", &input, 1 * sizeof(float));
- // run tensorrt op
+ // convert LoDTensor to ITensor
+ size_t size = x_tensor->memory_size();
+ EngineIOConverter::ConvertInput(op_type, *x_tensor,
+ engine->buffer("X").buffer, size, &stream);
+ // run the tensorrt op
engine->Execute(1);
-
- float out2;
- engine->GetOutputInCPU("Out", &out2, 1 * sizeof(float));
-
- ASSERT_EQ(out1[0], out2);
+ // convert ITensor to LoDTensor
+ EngineIOConverter::ConvertOutput(op_type, engine->buffer("Out").buffer,
+ out_tensor, size, &stream);
+ // get tensorrt output
+ std::vector out2;
+ framework::TensorToVector(*out_tensor, ctx, &out2);
+
+ // compare
+ ASSERT_EQ(out1[0], out2[0]);
ASSERT_EQ(out1[0], expect);
delete engine;
@@ -85,8 +95,8 @@ void Compare(float input, float expect) {
}
TEST(OpConverter, ConvertRelu) {
- Compare(1, 1); // relu(1) = 1
- Compare(-5, 0); // relu(-5) = 0
+ Compare("relu", 1, 1); // relu(1) = 1
+ Compare("relu", -5, 0); // relu(-5) = 0
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
index afcc516e6b..8f91309a0a 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_io_converter.cc
@@ -12,40 +12,63 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
+#include
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/tensorrt/convert/io_converter.h"
-#include
-
namespace paddle {
namespace inference {
namespace tensorrt {
-class EngineInputConverterTester : public ::testing::Test {
- public:
- void SetUp() override { tensor.Resize({10, 10}); }
+void IOConverterTester(const platform::DeviceContext& ctx) {
+ cudaStream_t stream;
+ ASSERT_EQ(0, cudaStreamCreate(&stream));
- framework::LoDTensor tensor;
-};
+ // init fluid in_tensor
+ framework::LoDTensor in_tensor;
+ in_tensor.Resize({10, 10});
+ auto place = ctx.GetPlace();
+ in_tensor.mutable_data(place);
+ std::vector init;
+ for (int64_t i = 0; i < 10 * 10; ++i) {
+ init.push_back(i);
+ }
+ framework::TensorFromVector(init, ctx, &in_tensor);
-TEST_F(EngineInputConverterTester, DefaultCPU) {
+ // init tensorrt buffer
void* buffer;
- tensor.mutable_data(platform::CPUPlace());
- ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+ size_t size = in_tensor.memory_size();
+ ASSERT_EQ(cudaMalloc(&buffer, size), 0);
- cudaStream_t stream;
- EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
- &stream);
+ // convert fluid in_tensor to tensorrt buffer
+ EngineIOConverter::ConvertInput("test", in_tensor, buffer, size, &stream);
+
+ // convert tensorrt buffer to fluid out_tensor
+ framework::LoDTensor out_tensor;
+ out_tensor.Resize({10, 10});
+ out_tensor.mutable_data(place);
+ EngineIOConverter::ConvertOutput("test", buffer, &out_tensor, size, &stream);
+
+ // compare in_tensor and out_tensor
+ std::vector result;
+ framework::TensorToVector(out_tensor, ctx, &result);
+ EXPECT_EQ(init.size(), result.size());
+ for (size_t i = 0; i < init.size(); i++) {
+ EXPECT_EQ(init[i], result[i]);
+ }
+ cudaStreamDestroy(stream);
}
-TEST_F(EngineInputConverterTester, DefaultGPU) {
- void* buffer;
- tensor.mutable_data(platform::CUDAPlace());
- ASSERT_EQ(cudaMalloc(&buffer, tensor.memory_size()), 0);
+TEST(EngineIOConverterTester, DefaultCPU) {
+ platform::CPUPlace place;
+ platform::CPUDeviceContext ctx(place);
+ IOConverterTester(ctx);
+}
- cudaStream_t stream;
- EngineInputConverter::Run("test", tensor, buffer, tensor.memory_size(),
- &stream);
+TEST(EngineIOConverterTester, DefaultGPU) {
+ platform::CUDAPlace place;
+ platform::CUDADeviceContext ctx(place);
+ IOConverterTester(ctx);
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
index aa5fb726f1..8d66543eb7 100644
--- a/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
+++ b/paddle/fluid/inference/tensorrt/convert/test_op_converter.cc
@@ -29,7 +29,7 @@ TEST(OpConverter, ConvertBlock) {
conv2d_op->SetType("conv2d");
OpConverter converter;
- converter.ConvertBlock(*block, nullptr /*TensorRTEngine*/);
+ converter.ConvertBlock(*block->Proto(), nullptr /*TensorRTEngine*/);
}
} // namespace tensorrt
diff --git a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
index c4fd1e298b..60c761c528 100644
--- a/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
+++ b/paddle/fluid/inference/tests/book/test_inference_image_classification.cc
@@ -16,7 +16,6 @@ limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
-DEFINE_string(data_set, "cifar10", "Data set to test");
DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_string(fp16_dirname, "", "Directory of the float16 inference model.");
DEFINE_int32(batch_size, 1, "Batch size of input data");
@@ -35,19 +34,19 @@ TEST(inference, image_classification) {
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// In unittests, this is done in paddle/testing/paddle_gtest_main.cc
+ const bool is_combined = false;
+ std::vector> feed_target_shapes =
+ GetFeedTargetShapes(dirname, is_combined);
+
paddle::framework::LoDTensor input;
// Use normilized image pixels as input data,
// which should be in the range [0.0, 1.0].
- if (FLAGS_data_set == "cifar10") {
- SetupTensor(&input, {FLAGS_batch_size, 3, 32, 32},
- static_cast(0), static_cast(1));
- } else if (FLAGS_data_set == "imagenet") {
- SetupTensor(&input, {FLAGS_batch_size, 3, 224, 224},
- static_cast(0), static_cast(1));
- } else {
- LOG(FATAL) << "Only cifar10 or imagenet is supported.";
- }
-
+ feed_target_shapes[0][0] = FLAGS_batch_size;
+ paddle::framework::DDim input_dims =
+ paddle::framework::make_ddim(feed_target_shapes[0]);
+ LOG(INFO) << input_dims;
+ SetupTensor(&input, input_dims, static_cast(0),
+ static_cast(1));
std::vector cpu_feeds;
cpu_feeds.push_back(&input);
@@ -60,7 +59,7 @@ TEST(inference, image_classification) {
LOG(INFO) << "--- CPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference(
- dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
+ dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat, is_combined);
LOG(INFO) << output1.dims();
}
@@ -73,7 +72,7 @@ TEST(inference, image_classification) {
LOG(INFO) << "--- GPU Runs: ---";
LOG(INFO) << "Batch size is " << FLAGS_batch_size;
TestInference(
- dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
+ dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat, is_combined);
LOG(INFO) << output2.dims();
if (!FLAGS_skip_cpu) {
diff --git a/paddle/fluid/inference/tests/test_helper.h b/paddle/fluid/inference/tests/test_helper.h
index af2a7a5620..b02e5c99f0 100644
--- a/paddle/fluid/inference/tests/test_helper.h
+++ b/paddle/fluid/inference/tests/test_helper.h
@@ -89,6 +89,50 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
}
+std::unique_ptr InitProgram(
+ paddle::framework::Executor* executor, paddle::framework::Scope* scope,
+ const std::string& dirname, const bool is_combined = false) {
+ std::unique_ptr inference_program;
+ if (is_combined) {
+ // All parameters are saved in a single file.
+ // Hard-coding the file names of program and parameters in unittest.
+ // The file names should be consistent with that used in Python API
+ // `fluid.io.save_inference_model`.
+ std::string prog_filename = "__model_combined__";
+ std::string param_filename = "__params_combined__";
+ inference_program =
+ paddle::inference::Load(executor, scope, dirname + "/" + prog_filename,
+ dirname + "/" + param_filename);
+ } else {
+ // Parameters are saved in separate files sited in the specified
+ // `dirname`.
+ inference_program = paddle::inference::Load(executor, scope, dirname);
+ }
+ return inference_program;
+}
+
+std::vector> GetFeedTargetShapes(
+ const std::string& dirname, const bool is_combined = false) {
+ auto place = paddle::platform::CPUPlace();
+ auto executor = paddle::framework::Executor(place);
+ auto* scope = new paddle::framework::Scope();
+
+ auto inference_program = InitProgram(&executor, scope, dirname, is_combined);
+ auto& global_block = inference_program->Block(0);
+
+ const std::vector& feed_target_names =
+ inference_program->GetFeedTargetNames();
+ std::vector> feed_target_shapes;
+ for (size_t i = 0; i < feed_target_names.size(); ++i) {
+ auto* var = global_block.FindVar(feed_target_names[i]);
+ std::vector var_shape = var->GetShape();
+ feed_target_shapes.push_back(var_shape);
+ }
+
+ delete scope;
+ return feed_target_shapes;
+}
+
template
void TestInference(const std::string& dirname,
const std::vector& cpu_feeds,
@@ -124,22 +168,7 @@ void TestInference(const std::string& dirname,
paddle::platform::RecordEvent record_event(
"init_program",
paddle::platform::DeviceContextPool::Instance().Get(place));
-
- if (is_combined) {
- // All parameters are saved in a single file.
- // Hard-coding the file names of program and parameters in unittest.
- // The file names should be consistent with that used in Python API
- // `fluid.io.save_inference_model`.
- std::string prog_filename = "__model_combined__";
- std::string param_filename = "__params_combined__";
- inference_program = paddle::inference::Load(
- &executor, scope, dirname + "/" + prog_filename,
- dirname + "/" + param_filename);
- } else {
- // Parameters are saved in separate files sited in the specified
- // `dirname`.
- inference_program = paddle::inference::Load(&executor, scope, dirname);
- }
+ inference_program = InitProgram(&executor, scope, dirname, is_combined);
}
// Disable the profiler and print the timing information
paddle::platform::DisableProfiler(paddle::platform::EventSortingKey::kDefault,
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index c14a2b7786..d38a9ce587 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -186,6 +186,11 @@ endif()
add_subdirectory(detail)
if(WITH_DISTRIBUTE)
+ if(WITH_GPU)
+ op_library(gen_nccl_id_op DEPS nccl_common)
+ else()
+ set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
+ endif()
set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
op_library(send_op DEPS ${DISTRIBUTE_DEPS})
@@ -202,8 +207,9 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor)
+ cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor)
else()
- set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op)
+ set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op)
endif()
op_library(cross_entropy_op DEPS cross_entropy)
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index 661dfa69fe..ae60ab1532 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -52,7 +52,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
// stub context
SendProcessor* s = new SendProcessor(ch);
s->Prepare(var_h, time_out);
- s->response_call_back_ = NULL;
+ s->response_call_back_ = nullptr;
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
diff --git a/paddle/fluid/operators/detail/grpc_client.h b/paddle/fluid/operators/detail/grpc_client.h
index f6229b71bc..dabce7414d 100644
--- a/paddle/fluid/operators/detail/grpc_client.h
+++ b/paddle/fluid/operators/detail/grpc_client.h
@@ -57,7 +57,9 @@ void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
- explicit BaseProcessor(std::shared_ptr ch) { context_ = NULL; }
+ explicit BaseProcessor(std::shared_ptr ch) {
+ context_ = nullptr;
+ }
virtual ~BaseProcessor() {}
@@ -105,7 +107,7 @@ class SendProcessor : public BaseProcessor {
::grpc::GenericStub stub_g_;
::grpc::ByteBuffer reply_;
- RequestSendCallBack response_call_back_ = NULL;
+ RequestSendCallBack response_call_back_ = nullptr;
};
typedef std::function
diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc
index e6ee28ea8d..d09f8479b7 100644
--- a/paddle/fluid/operators/detail/grpc_server.cc
+++ b/paddle/fluid/operators/detail/grpc_server.cc
@@ -306,7 +306,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
}
RequestPrefetch* prefetch =
new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_,
- dev_ctx_, executor_, program_, prefetch_ctx_);
+ dev_ctx_, executor_, program_, prefetch_ctx_.get());
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status();
}
diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h
index 7f9cae21cc..238aaa2963 100644
--- a/paddle/fluid/operators/detail/grpc_server.h
+++ b/paddle/fluid/operators/detail/grpc_server.h
@@ -47,6 +47,7 @@ class AsyncGRPCServer final {
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {}
+ ~AsyncGRPCServer() {}
void WaitServerReady();
void RunSyncUpdate();
@@ -63,8 +64,9 @@ class AsyncGRPCServer final {
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
- void SetPrefetchPreparedCtx(framework::ExecutorPrepareContext *prepared) {
- prefetch_ctx_ = prepared;
+ void SetPrefetchPreparedCtx(
+ std::unique_ptr prepared) {
+ prefetch_ctx_.reset(prepared.release());
}
int GetSelectedPort() const { return selected_port_; }
@@ -115,7 +117,7 @@ class AsyncGRPCServer final {
std::unique_ptr t_get_;
std::unique_ptr t_prefetch_;
- framework::ExecutorPrepareContext *prefetch_ctx_;
+ std::unique_ptr prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc
index 25b95d608d..b8db0ad987 100644
--- a/paddle/fluid/operators/detail/grpc_server_test.cc
+++ b/paddle/fluid/operators/detail/grpc_server_test.cc
@@ -100,7 +100,7 @@ void StartServer(const std::string& endpoint) {
InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program);
- rpc_service_->SetPrefetchPreparedCtx(prepared.get());
+ rpc_service_->SetPrefetchPreparedCtx(std::move(prepared));
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);
diff --git a/paddle/fluid/operators/detail/send_recv.proto b/paddle/fluid/operators/detail/send_recv.proto
index fffa9ae7a4..9478c5702b 100644
--- a/paddle/fluid/operators/detail/send_recv.proto
+++ b/paddle/fluid/operators/detail/send_recv.proto
@@ -32,6 +32,7 @@ service SendRecvService {
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
+ NCCL_ID = 2;
}
// NOTICE(gongwb):don't modify this proto if you are not
diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc
index 1a8a1af20f..07c43554bc 100644
--- a/paddle/fluid/operators/detail/sendrecvop_utils.cc
+++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc
@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
+#ifdef PADDLE_WITH_CUDA
+#include
+#endif
#include
#include // NOLINT
@@ -129,6 +132,10 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
} else if (var->IsType()) {
request.set_type(::sendrecv::SELECTED_ROWS);
GetSelectedRowsPayload(var, ctx, &request, &payload, &payload_size);
+#ifdef PADDLE_WITH_CUDA
+ } else if (var->IsType()) {
+ request.set_type(::sendrecv::NCCL_ID);
+#endif
} else {
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
@@ -149,6 +156,24 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
void* buf = buffer.get();
ProtoEncodeHelper e(static_cast(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size()));
+// NCCLID is copied directly to the message, return bytebuffer
+// with only one slice if serializing NCCLID.
+#ifdef PADDLE_WITH_CUDA
+ if (var->IsType()) {
+ e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
+ NCCL_UNIQUE_ID_BYTES);
+ const ncclUniqueId& uid = var->Get();
+ e.WriteRawBytes(std::string(uid.internal, NCCL_UNIQUE_ID_BYTES));
+
+ // for serialize NCCL_ID
+ ::grpc::Slice slices(e.size());
+ memcpy(const_cast(slices.begin()), e.data(), e.size());
+ ::grpc::ByteBuffer tmp(&slices, 1);
+ msg->Swap(&tmp);
+ return;
+ }
+#endif
+
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
diff --git a/paddle/fluid/operators/detail/variable_response.cc b/paddle/fluid/operators/detail/variable_response.cc
index 99602a05d0..462e303096 100644
--- a/paddle/fluid/operators/detail/variable_response.cc
+++ b/paddle/fluid/operators/detail/variable_response.cc
@@ -17,6 +17,9 @@
#include
#include
#include
+#ifdef PADDLE_WITH_CUDA
+#include
+#endif
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
@@ -368,7 +371,8 @@ int VariableResponse::Parse(Source* source) {
}
case sendrecv::VariableMessage::kSerializedFieldNumber: {
PADDLE_ENFORCE((meta_.type() == sendrecv::SELECTED_ROWS ||
- meta_.type() == sendrecv::LOD_TENSOR) &&
+ meta_.type() == sendrecv::LOD_TENSOR ||
+ meta_.type() == sendrecv::NCCL_ID) &&
meta_.varname() != "",
"meta info should be got first!");
@@ -378,6 +382,22 @@ int VariableResponse::Parse(Source* source) {
return tag;
}
+ if (meta_.type() == sendrecv::NCCL_ID) {
+#ifdef PADDLE_WITH_CUDA
+ auto* var = scope_->FindVar(meta_.varname());
+ if (var != nullptr) {
+ ncclUniqueId* id = var->GetMutable();
+ if (!ReadRaw(&input, *dev_ctx_, platform::CPUPlace(), id->internal,
+ num_bytes)) {
+ return tag;
+ }
+ }
+ break;
+#else
+ PADDLE_THROW("Not compiled with CUDA!");
+#endif
+ }
+
framework::DDim dims = GetDims(meta_.dims());
if (meta_.type() == sendrecv::LOD_TENSOR) {
PADDLE_ENFORCE(meta_.lod_size() >= 0,
diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc
new file mode 100644
index 0000000000..a5678f6346
--- /dev/null
+++ b/paddle/fluid/operators/gen_nccl_id_op.cc
@@ -0,0 +1,128 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+#include
+#include
+
+#include "paddle/fluid/framework/executor.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/threadpool.h"
+#include "paddle/fluid/operators/detail/grpc_client.h"
+#include "paddle/fluid/operators/detail/grpc_server.h"
+#include "paddle/fluid/platform/nccl_helper.h"
+
+namespace paddle {
+namespace operators {
+
+class GenNCCLIdOp : public framework::OperatorBase {
+ public:
+ GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
+ const framework::VariableNameMap& outputs,
+ const framework::AttributeMap& attrs)
+ : OperatorBase(type, inputs, outputs, attrs) {}
+
+ void RunImpl(const framework::Scope& scope,
+ const platform::Place& dev_place) const override {
+ platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+ // put nccl id in CPUPlace
+ auto& dev_ctx = *pool.Get(platform::CPUPlace());
+ int trainer_id = Attr("trainer_id");
+ framework::Scope& local_scope = scope.NewScope();
+
+ if (trainer_id == 0) {
+ GenerateAndSend(&local_scope, dev_ctx);
+ } else {
+ GetIdByServer(&local_scope, dev_ctx);
+ }
+ }
+
+ private:
+ void GenerateAndSend(framework::Scope* scope,
+ const platform::DeviceContext& dev_ctx) const {
+ auto var = scope->FindVar(NCCL_ID_VARNAME);
+ PADDLE_ENFORCE_NOT_NULL(var);
+ auto id = var->GetMutable();
+ PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
+
+ std::vector endpoint_list =
+ Attr>("endpoint_list");
+ detail::RPCClient client;
+ for (auto& ep : endpoint_list) {
+ VLOG(3) << "sending nccl id to " << ep;
+ client.AsyncSendVariable(ep, dev_ctx, *scope, NCCL_ID_VARNAME);
+ }
+ client.Wait();
+ VLOG(3) << "sending completed...";
+ }
+
+ void GetIdByServer(framework::Scope* scope,
+ const platform::DeviceContext& dev_ctx) const {
+ std::string endpoint = Attr("endpoint");
+ // NOTE: cannot use unique_ptr here because the default
+ // deleter will call GRPC Server's base class's dtor and
+ // that will cause a weird crash.
+ detail::AsyncGRPCServer rpc_service(endpoint, true);
+ framework::ProgramDesc empty_program;
+ framework::Executor executor(dev_ctx.GetPlace());
+ rpc_service.SetScope(scope);
+ rpc_service.SetDevCtx(&dev_ctx);
+ rpc_service.SetProgram(&empty_program);
+ rpc_service.SetExecutor(&executor);
+
+ std::thread server_thread(
+ std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
+ rpc_service.SetCond(0);
+ VLOG(3) << "start getting nccl id from trainer 0...";
+ auto recv = rpc_service.Get();
+ VLOG(3) << "got nccl id and stop server...";
+ rpc_service.ShutDown();
+ VLOG(3) << "rpc server stopped";
+ server_thread.join();
+ }
+};
+
+class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ void Make() override {
+ AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
+ AddComment(R"DOC(
+GenNCCLId operator
+
+For trainer 0: generate a new UniqueId and send it to all the other trainers.
+For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
+)DOC");
+ AddAttr("endpoint",
+ "(string), e.g. 127.0.0.1:6175 "
+ "current listen endpoint");
+ AddAttr>(
+ "endpoint_list",
+ "['trainer1_ip:port', 'trainer2_ip:port', ...] "
+ "list of trainer endpoints start from trainer 1")
+ .SetDefault({});
+ AddAttr("trainer_id",
+ "(int default 0) "
+ "The index of the trainer in distributed training.")
+ .SetDefault(0);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(gen_nccl_id, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker);
diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc
index a29e0cd52c..abc88d3eb1 100644
--- a/paddle/fluid/operators/listen_and_serv_op.cc
+++ b/paddle/fluid/operators/listen_and_serv_op.cc
@@ -322,8 +322,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
- rpc_service_->SetPrefetchPreparedCtx(prefetch_prepared.get());
- prefetch_prepared.release();
+ rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared));
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc
index b5522dd246..0522a94195 100644
--- a/paddle/fluid/operators/load_combine_op.cc
+++ b/paddle/fluid/operators/load_combine_op.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include
-
+#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
@@ -31,6 +31,7 @@ class LoadCombineOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto filename = Attr("file_path");
+ auto load_as_fp16 = Attr("load_as_fp16");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast(fin),
@@ -59,17 +60,25 @@ class LoadCombineOp : public framework::OperatorBase {
// Get data from fin to tensor
DeserializeFromStream(fin, tensor, dev_ctx);
- if (platform::is_gpu_place(place)) {
- // copy CPU to GPU
- framework::LoDTensor cpu_tensor;
- cpu_tensor.ShareDataWith(*tensor);
- cpu_tensor.set_lod(tensor->lod());
-
- // reset tensor
+ auto in_dtype = framework::ToDataType(tensor->type());
+ auto out_dtype =
+ load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
+
+ if (in_dtype != out_dtype) {
+ // convert to float16 tensor
+ auto in_kernel_type = framework::OpKernelType(in_dtype, place);
+ auto out_kernel_type = framework::OpKernelType(out_dtype, place);
+ framework::LoDTensor fp16_tensor;
+ // copy LoD info to the new tensor
+ fp16_tensor.set_lod(tensor->lod());
+ framework::TransDataType(in_kernel_type, out_kernel_type, *tensor,
+ &fp16_tensor);
+
+ // reset output tensor
out_var->Clear();
tensor = out_var->GetMutable();
- tensor->set_lod(cpu_tensor.lod());
- TensorCopy(cpu_tensor, place, dev_ctx, tensor);
+ tensor->set_lod(fp16_tensor.lod());
+ tensor->ShareDataWith(fp16_tensor);
}
}
}
@@ -82,6 +91,13 @@ class LoadCombineOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"Out",
"(vector) The output LoDTensors that will be read from the input file.")
.AsDuplicable();
+ AddAttr(
+ "load_as_fp16",
+ "(boolean, default false)"
+ "If true, the tensor will be first loaded and then "
+ "converted to float16 data type. Otherwise, the tensor will be "
+ "directly loaded without data type conversion.")
+ .SetDefault(false);
AddAttr("file_path",
"(string) "
"LoDTensors will be loaded from \"file_path\".")
diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/fluid/operators/math/sequence2batch.h
index 0abda999a5..62e6307ae9 100644
--- a/paddle/fluid/operators/math/sequence2batch.h
+++ b/paddle/fluid/operators/math/sequence2batch.h
@@ -64,18 +64,22 @@ class LoDTensor2BatchFunctor {
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch->lod();
- PADDLE_ENFORCE_GT(lods.size(), 2UL);
- PADDLE_ENFORCE_EQ(lods[1].size(),
- static_cast(lod_tensor.dims()[0]));
+ PADDLE_ENFORCE_GT(lods.size(), 2UL,
+ "The LoD of LoDTensor should inlcude at least 2-level "
+ "sequence information.");
+ PADDLE_ENFORCE_EQ(
+ lods[1].size(), static_cast(lod_tensor.dims()[0]),
+ "The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor to_batch;
to_batch(context, lod_tensor, lods[1], batch, true);
return;
}
auto lods = lod_tensor.lod();
- auto lod = lods[0];
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
+ auto lod = lods[0];
+
std::vector seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
@@ -157,9 +161,12 @@ class Batch2LoDTensorFunctor {
const framework::LoDTensor& batch,
framework::LoDTensor* lod_tensor) const {
auto in_lod = batch.lod();
- PADDLE_ENFORCE_GT(in_lod.size(), 2UL);
- PADDLE_ENFORCE_EQ(in_lod[1].size(),
- static_cast(lod_tensor->dims()[0]));
+ PADDLE_ENFORCE_GT(in_lod.size(), 2UL,
+ "The LoD of LoDTensor should inlcude at least 2-level "
+ "sequence information.");
+ PADDLE_ENFORCE_EQ(
+ in_lod[1].size(), static_cast(lod_tensor->dims()[0]),
+ "The LoD information should be consistent with the dims.");
CopyMatrixRowsFunctor to_seq;
to_seq(context, batch, in_lod[1], lod_tensor, false);
}
diff --git a/paddle/fluid/operators/reshape_op.h b/paddle/fluid/operators/reshape_op.h
index ccd7063fe6..3dd8c7c11e 100644
--- a/paddle/fluid/operators/reshape_op.h
+++ b/paddle/fluid/operators/reshape_op.h
@@ -92,14 +92,16 @@ class ReshapeOp : public framework::OperatorWithKernel {
}
if (unk_dim_idx != -1) {
- output_shape[unk_dim_idx] = -in_size / capacity;
- // in_size < 0 and is un-determinate in compile time, skip the check,
- // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
- // capacity = -24, in_size = -8, output_shape[0] = 0
- // the following check will fail.
if (in_size > 0) {
+ // in_size < 0 and is un-determinate in compile time, skip the check,
+ // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
+ // capacity = -24, in_size = -8, output_shape[0] = 0
+ // the following check will fail.
+ output_shape[unk_dim_idx] = -in_size / capacity;
PADDLE_ENFORCE_EQ(output_shape[unk_dim_idx] * capacity, -in_size,
"Invalid shape is given.");
+ } else {
+ output_shape[unk_dim_idx] = -1;
}
} else {
PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
@@ -122,7 +124,10 @@ class ReshapeKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext &ctx) const {
auto *out = ctx.Output("Out");
auto *in = ctx.Input("X");
- auto *shape_tensor = ctx.Input("Shape");
+
+ auto *shape_tensor = ctx.HasInput("Shape")
+ ? ctx.Input("Shape")
+ : nullptr;
framework::DDim out_dims = out->dims();
diff --git a/paddle/fluid/operators/save_load_combine_op_test.cc b/paddle/fluid/operators/save_load_combine_op_test.cc
index 47618c51d9..4743e0d949 100644
--- a/paddle/fluid/operators/save_load_combine_op_test.cc
+++ b/paddle/fluid/operators/save_load_combine_op_test.cc
@@ -139,8 +139,9 @@ TEST(SaveLoadCombineOp, CPU) {
CheckValues(expect4, actual4, expect_lod4, actual_lod4, numel4);
}
-// FP16 version of SaveLoadCombineOp Test
-TEST(SaveLoadCombineFP16Op, CPU) {
+// FP16 version of SaveLoadCombineOp Test, only altering the saving aspect
+// to save as FP16.
+TEST(SaveCombineFP16Op, CPU) {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
@@ -169,7 +170,7 @@ TEST(SaveLoadCombineFP16Op, CPU) {
20, 50, lod4, "test_var4", place, &scope, &expect_lod4);
// Set attributes
- std::string filename = "check_tensor_fp16.ls";
+ std::string filename = "check_tensor_fp16_save.ls";
paddle::framework::AttributeMap attrs;
attrs.insert({"file_path", std::string(filename)});
attrs.insert({"save_as_fp16", true});
@@ -216,6 +217,89 @@ TEST(SaveLoadCombineFP16Op, CPU) {
actual_lod4, numel4);
}
+// FP16 version of SaveLoadCombineOp Test, only altering the loading aspect
+// to load tensors with FP16 precision.
+TEST(LoadCombineFP16Op, CPU) {
+ paddle::framework::Scope scope;
+ paddle::platform::CPUPlace place;
+
+ std::vector lod1 = {0, 1, 2, 3, 10};
+ int numel1 = 100;
+ paddle::framework::LoD expect_lod1;
+ float* expect1 = CreateForSaveCombineOp(
+ 10, 10, lod1, "test_var1", place, &scope, &expect_lod1);
+
+ std::vector lod2 = {0, 2, 5, 10};
+ int numel2 = 200;
+ paddle::framework::LoD expect_lod2;
+ float* expect2 = CreateForSaveCombineOp(
+ 10, 20, lod2, "test_var2", place, &scope, &expect_lod2);
+
+ std::vector lod3 = {0, 20};
+ int numel3 = 4000;
+ paddle::framework::LoD expect_lod3;
+ float* expect3 = CreateForSaveCombineOp(
+ 20, 200, lod3, "test_var3", place, &scope, &expect_lod3);
+
+ std::vector lod4 = {0, 1, 20};
+ int numel4 = 1000;
+ paddle::framework::LoD expect_lod4;
+ float* expect4 = CreateForSaveCombineOp(
+ 20, 50, lod4, "test_var4", place, &scope, &expect_lod4);
+
+ // Set attributes
+ std::string filename = "check_tensor_fp16_load.ls";
+ paddle::framework::AttributeMap attrs;
+ attrs.insert({"file_path", std::string(filename)});
+
+ // Run the save_combine_op
+ auto save_combine_op = paddle::framework::OpRegistry::CreateOp(
+ "save_combine",
+ {{"X", {"test_var1", "test_var2", "test_var3", "test_var4"}}}, {}, attrs);
+ save_combine_op->Run(scope, place);
+
+ // Set up output vars
+ auto load_var1 = scope.Var("out_var1");
+ auto load_var2 = scope.Var("out_var2");
+ auto load_var3 = scope.Var("out_var3");
+ auto load_var4 = scope.Var("out_var4");
+
+ attrs.insert({"load_as_fp16", true});
+ // Run the load_combine_op
+ auto load_combine_op = paddle::framework::OpRegistry::CreateOp(
+ "load_combine", {},
+ {{"Out", {"out_var1", "out_var2", "out_var3", "out_var4"}}}, attrs);
+ load_combine_op->Run(scope, place);
+
+ auto* target1 = load_var1->GetMutable