Merge branch 'develop' of github.com:PaddlePaddle/Paddle into parallel_graph_mode

revert-15207-remove_op_handle_lock_and_fix_var
Yancey1989 7 years ago
commit 2dda19f756

@ -166,6 +166,8 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator # Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif() endif()

@ -194,6 +194,8 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))

@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache) shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
@ -191,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto op_kernel_type)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(tuple_test SRCS tuple_test.cc ) cc_test(tuple_test SRCS tuple_test.cc )

@ -33,11 +33,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
CheckInit(); CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) { for (size_t i = 0; i < use_slots_.size(); ++i) {
if (name == use_slots_[i]) { if (name == use_slots_[i]) {
if (use_slots_is_dense_[i]) { feed_vec_[i] = var->GetMutable<LoDTensor>();
feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
} else {
feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
}
} }
} }
} }
@ -301,6 +297,7 @@ bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
"the data, please check if the data contains unresolvable " "the data, please check if the data contains unresolvable "
"characters.\nplease check this error line: %s", "characters.\nplease check this error line: %s",
str); str);
if (idx != -1) { if (idx != -1) {
(*instance)[idx].Init(all_slots_type_[i]); (*instance)[idx].Init(all_slots_type_[i]);
if ((*instance)[idx].GetType()[0] == 'f') { // float if ((*instance)[idx].GetType()[0] == 'f') { // float
@ -337,6 +334,7 @@ void MultiSlotDataFeed::AddInstanceToInsVec(
(*ins_vec)[i].InitOffset(); (*ins_vec)[i].InitOffset();
} }
} }
for (size_t i = 0; i < instance.size(); ++i) { for (size_t i = 0; i < instance.size(); ++i) {
(*ins_vec)[i].AddIns(instance[i]); (*ins_vec)[i].AddIns(instance[i]);
} }
@ -348,36 +346,25 @@ void MultiSlotDataFeed::PutToFeedVec(
const auto& type = ins_vec[i].GetType(); const auto& type = ins_vec[i].GetType();
const auto& offset = ins_vec[i].GetOffset(); const auto& offset = ins_vec[i].GetOffset();
int total_instance = static_cast<int>(offset.back()); int total_instance = static_cast<int>(offset.back());
if (type[0] == 'f') { // float if (type[0] == 'f') { // float
const auto& feasign = ins_vec[i].GetFloatData(); const auto& feasign = ins_vec[i].GetFloatData();
if (feed_vec_[i].IsDense()) { float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
int size_in_each_batch = total_instance / batch_size_;
float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<float>(
{batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
} else {
float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<float>(
{total_instance, 1}, platform::CPUPlace()); {total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float)); memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
}
} else if (type[0] == 'u') { // uint64 } else if (type[0] == 'u') { // uint64
// no uint64_t type in paddlepaddle // no uint64_t type in paddlepaddle
const auto& feasign = ins_vec[i].GetUint64Data(); const auto& feasign = ins_vec[i].GetUint64Data();
if (feed_vec_[i].IsDense()) { int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
int size_in_each_batch = total_instance / batch_size_;
int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<int64_t>(
{batch_size_, size_in_each_batch}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
} else {
int64_t* tensor_ptr =
feed_vec_[i].GetLoDTensor()->mutable_data<int64_t>(
{total_instance, 1}, platform::CPUPlace()); {total_instance, 1}, platform::CPUPlace());
memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t)); memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
LoD data_lod{offset};
feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
} }
LoD data_lod{offset};
feed_vec_[i]->set_lod(data_lod);
if (use_slots_is_dense_[i]) {
int dim = total_instance / batch_size_;
feed_vec_[i]->Resize({batch_size_, dim});
} }
} }
} }

@ -30,35 +30,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Pack Tensor type and LoDTensor type into MixTensor type, in order
// to record either Tensor or LoDTensor information at the same time.
class MixTensor {
public:
MixTensor() {}
explicit MixTensor(LoDTensor* lodtensor) {
is_dense_ = false;
lodtensor_ = lodtensor;
}
explicit MixTensor(Tensor* tensor) {
is_dense_ = true;
tensor_ = tensor;
}
bool IsDense() { return is_dense_; }
LoDTensor* GetLoDTensor() {
PADDLE_ENFORCE(!is_dense_, "Let a dense var return a LoDTensor ptr.");
return lodtensor_;
}
Tensor* GetTensor() {
PADDLE_ENFORCE(is_dense_, "Let a sparse var return a Tensor ptr.");
return tensor_;
}
private:
bool is_dense_;
LoDTensor* lodtensor_;
Tensor* tensor_;
};
// DataFeed is the base virtual class for all ohther DataFeeds. // DataFeed is the base virtual class for all ohther DataFeeds.
// It is used to read files and parse the data for subsequent trainer. // It is used to read files and parse the data for subsequent trainer.
// Example: // Example:
@ -133,7 +104,7 @@ class DataFeed {
use_slots_index_; // -1: not used; >=0: the index of use_slots_ use_slots_index_; // -1: not used; >=0: the index of use_slots_
// The data read by DataFeed will be stored here // The data read by DataFeed will be stored here
std::vector<MixTensor> feed_vec_; std::vector<LoDTensor*> feed_vec_;
// the batch size defined by user // the batch size defined by user
int default_batch_size_; int default_batch_size_;

@ -152,21 +152,15 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
const auto& multi_slot_desc = data_feed_desc.multi_slot_desc(); const auto& multi_slot_desc = data_feed_desc.multi_slot_desc();
std::map<std::string, const paddle::framework::LoDTensor*> std::map<std::string, const paddle::framework::LoDTensor*>
lodtensor_targets; lodtensor_targets;
std::map<std::string, const paddle::framework::Tensor*> tensor_targets;
for (int i = 0; i < multi_slot_desc.slots_size(); ++i) { for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
const auto& slot = multi_slot_desc.slots(i); const auto& slot = multi_slot_desc.slots(i);
if (slot.is_used()) { if (slot.is_used()) {
const auto& name = slot.name(); const auto& name = slot.name();
readers[idx]->AddFeedVar(scope->Var(name), name); readers[idx]->AddFeedVar(scope->Var(name), name);
if (slot.is_dense()) {
tensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::Tensor>();
} else {
lodtensor_targets[name] = lodtensor_targets[name] =
&scope->FindVar(name)->Get<paddle::framework::LoDTensor>(); &scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
} }
} }
}
readers[idx]->Start(); readers[idx]->Start();
while (readers[idx]->Next()) { while (readers[idx]->Next()) {
int index = 0; int index = 0;
@ -175,8 +169,9 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
if (!slot.is_used()) { if (!slot.is_used()) {
continue; continue;
} }
const paddle::framework::LoDTensor* tens =
lodtensor_targets[slot.name()];
if (slot.is_dense()) { // dense branch if (slot.is_dense()) { // dense branch
const paddle::framework::Tensor* tens = tensor_targets[slot.name()];
if (slot.type() == "uint64") { if (slot.type() == "uint64") {
const int64_t* data = tens->data<int64_t>(); const int64_t* data = tens->data<int64_t>();
int batch_size = tens->dims()[0]; int batch_size = tens->dims()[0];
@ -202,8 +197,6 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
PADDLE_THROW("Error type in proto file."); PADDLE_THROW("Error type in proto file.");
} }
} else { // sparse branch } else { // sparse branch
const paddle::framework::LoDTensor* tens =
lodtensor_targets[slot.name()];
if (slot.type() == "uint64") { if (slot.type() == "uint64") {
const int64_t* data = tens->data<int64_t>(); const int64_t* data = tens->data<int64_t>();
for (size_t i = 0; i < tens->NumElements(); ++i) { for (size_t i = 0; i < tens->NumElements(); ++i) {

@ -15,14 +15,26 @@ cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_ro
if(WITH_GPU) if(WITH_GPU)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor) dynload_cuda variable_visitor)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim dynload_cuda) if(WITH_DISTRIBUTE)
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor sendrecvop_grpc)
else()
nv_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim dynload_cuda selected_rows_functor)
endif()
nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda) nv_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
else() else()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor) variable_visitor)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope ddim) if(WITH_DISTRIBUTE)
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor sendrecvop_grpc)
else()
cc_library(reduce_op_handle SRCS reduce_op_handle.cc DEPS op_handle_base variable_visitor scope
ddim selected_rows_functor)
endif()
cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor) cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) cc_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
endif() endif()

@ -58,6 +58,17 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
} }
} }
CollectiveContext *context = CollectiveContext::GetInstance();
context->endpoints_ = strategy_.trainers_endpoints_;
context->trainer_id_ = strategy_.trainer_id_;
PADDLE_ENFORCE(strategy_.trainer_id_ >= 0, "trainer_id_ >= 0");
if (strategy_.trainer_id_ > 0) {
PADDLE_ENFORCE((unsigned)(strategy_.trainer_id_) <
strategy_.trainers_endpoints_.size(),
"trainer_id_ < endpoints_ size");
}
VLOG(1) << "CollectiveContext:" << context->String();
// Convert graph to run on multi-devices. // Convert graph to run on multi-devices.
auto multi_devices_pass = AppendPass("multi_devices_pass"); auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
@ -135,7 +146,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx); pass->SetNotOwned<platform::NCCLContextMap>("nccl_ctxs", nctx);
#endif #endif
} else if (pass->Type() == "sequential_execution_pass") { } else if (pass->Type() == "sequential_execution_pass") {
VLOG(1) << "set enable_sequential_execution:" LOG(INFO) << "set enable_sequential_execution:"
<< enable_sequential_execution_; << enable_sequential_execution_;
pass->Erase(kAllOpDescs); pass->Erase(kAllOpDescs);
@ -143,7 +154,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
kAllOpDescs, kAllOpDescs,
new std::vector<OpDesc *>(main_program.Block(0).AllOps())); new std::vector<OpDesc *>(main_program.Block(0).AllOps()));
} else if (pass->Type() == "all_reduce_deps_pass") { } else if (pass->Type() == "all_reduce_deps_pass") {
VLOG(1) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this) LOG(INFO) << "SeqOnlyAllReduceOps:" << SeqOnlyAllReduceOps(*this)
<< ", num_trainers:" << num_trainers_; << ", num_trainers:" << num_trainers_;
pass->Erase(kAllOpDescs); pass->Erase(kAllOpDescs);

@ -74,6 +74,8 @@ struct BuildStrategy {
bool fuse_broadcast_op_{false}; bool fuse_broadcast_op_{false};
int num_trainers_{1}; int num_trainers_{1};
int trainer_id_{0};
std::vector<std::string> trainers_endpoints_;
bool remove_unnecessary_lock_{false}; bool remove_unnecessary_lock_{false};
// NOTE: // NOTE:

@ -53,7 +53,7 @@ struct ReduceLoDTensor {
} }
}; };
inline void GatherSelectedRows( inline void GatherLocalSelectedRows(
const std::vector<const SelectedRows *> &src_selecte_rows_, const std::vector<const SelectedRows *> &src_selecte_rows_,
const std::vector<platform::Place> &in_places, const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes, const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,

@ -16,6 +16,12 @@
#include "paddle/fluid/framework/details/container_cast.h" #include "paddle/fluid/framework/details/container_cast.h"
#include "paddle/fluid/framework/details/reduce_and_gather.h" #include "paddle/fluid/framework/details/reduce_and_gather.h"
#include "paddle/fluid/framework/details/variable_visitor.h" #include "paddle/fluid/framework/details/variable_visitor.h"
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/collective_client.h"
#include "paddle/fluid/operators/distributed/collective_server.h"
#include "paddle/fluid/operators/distributed/request_handler.h"
#endif
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
DEFINE_bool( DEFINE_bool(
@ -26,6 +32,112 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
std::once_flag CollectiveContext::init_flag_;
std::unique_ptr<CollectiveContext> CollectiveContext::context_;
static inline std::string GetRemoteVarName(const std::string &var_name,
int trainer_id) {
return string::Sprintf("%s_merged_tmp@trainer_%d", var_name, trainer_id);
}
void ReduceOpHandle::Wait(
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes) {
// TODO(gongwb): use event wait?
for (auto &dev_ctx : dev_ctxes) {
dev_ctx.second->Wait();
}
}
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
template <typename DevCtx, typename DataType>
void ReduceOpHandle::GatherSelectedRows(
const std::vector<const SelectedRows *> &src_selected_rows,
const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
VarHandle *out_var_handle, const platform::Place &out_place,
SelectedRows *dst_selected_rows) {
const CollectiveContext &collective_context =
*CollectiveContext::GetInstance();
// 1. gather local selected rows, merge them
std::string gathered_var_name = out_var_handle->name_ + "_gathered_tmp";
auto scope = local_scopes_.at(out_var_handle->scope_idx_);
auto gathered_var_mid = scope->Var(gathered_var_name);
auto gathered_select_rows =
gathered_var_mid->GetMutable<framework::SelectedRows>();
GatherLocalSelectedRows(src_selected_rows, in_places, dev_ctxes, out_place,
gathered_select_rows);
// FIXME(gongwb): remove this Wait.
Wait(dev_ctxes);
// merge them
auto merged_dev_ctx = dynamic_cast<DevCtx *>(dev_ctxes.at(out_place));
std::string merged_var_name =
GetRemoteVarName(out_var_handle->name_, collective_context.trainer_id_);
auto merged_select_rows =
scope->Var(merged_var_name)->GetMutable<SelectedRows>();
operators::math::scatter::MergeAdd<DevCtx, DataType> merge_func;
merge_func(*merged_dev_ctx, *gathered_select_rows, merged_select_rows);
// 2. start collective server if it doesn't exist
operators::distributed::CollectiveServer *server =
operators::distributed::CollectiveServer::GetInstance(
collective_context.endpoints_[collective_context.trainer_id_],
collective_context.endpoints_.size() - 1);
auto rpc_server = server->GetRPCServer();
rpc_server->RegisterVar(merged_var_name,
operators::distributed::kRequestGetMonomerVariable,
scope, merged_dev_ctx);
// 3. gather them from all remote nodes.
std::vector<const SelectedRows *> remote;
operators::distributed::CollectiveClient *client =
operators::distributed::CollectiveClient::GetInstance();
std::vector<operators::distributed::RemoteVar> vars;
for (unsigned int i = 0; i < collective_context.endpoints_.size(); i++) {
if (i == (unsigned)collective_context.trainer_id_) continue;
operators::distributed::RemoteVar var;
var.trainer_id_ = i;
var.var_name_ = GetRemoteVarName(out_var_handle->name_, i);
var.ep_ = collective_context.endpoints_[i];
vars.push_back(var);
VLOG(4) << "gather from:" << var.String();
}
// erase gathered vars
merged_dev_ctx->Wait();
scope->EraseVars(std::vector<std::string>{gathered_var_name});
PADDLE_ENFORCE(client->Gather(vars, &remote, *merged_dev_ctx, scope));
PADDLE_ENFORCE(remote.size() == vars.size());
// 4. merged local selected rows.
std::vector<const SelectedRows *> all;
all.resize(collective_context.endpoints_.size());
for (auto v : vars) {
all[v.trainer_id_] =
scope->FindVar(v.var_name_)->GetMutable<SelectedRows>();
}
all[collective_context.trainer_id_] = merged_select_rows;
merge_func(*merged_dev_ctx, all, dst_selected_rows);
rpc_server->WaitVarBarrier(merged_var_name);
rpc_server->ClearVar(merged_var_name);
// 5. clear mid vars
std::vector<std::string> tmp_vars{merged_var_name};
for (auto r : vars) {
tmp_vars.push_back(r.var_name_);
}
scope->EraseVars(tmp_vars);
}
#endif
void ReduceOpHandle::RunImpl() { void ReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second); platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
@ -90,8 +202,36 @@ void ReduceOpHandle::RunImpl() {
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
std::vector<const SelectedRows *> in_selected_rows = std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes); GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
const CollectiveContext &collective_context =
*CollectiveContext::GetInstance();
VLOG(10) << "GatherSelectedRows CollectiveContext:"
<< collective_context.String();
// TODO(gongwb): add cpu support
if (collective_context.endpoints_.size() <= 1 ||
is_cpu_place(in_places[0]) || is_cpu_place(t_out_p)) {
GatherLocalSelectedRows(in_selected_rows, in_places, dev_ctxes_,
t_out_p,
out_var->GetMutable<framework::SelectedRows>());
return;
}
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
if (framework::IsType<const float>(in_selected_rows[0]->value().type())) {
GatherSelectedRows<platform::CUDADeviceContext, float>(
in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
out_var->GetMutable<framework::SelectedRows>()); out_var->GetMutable<framework::SelectedRows>());
} else if (framework::IsType<const double>(
in_selected_rows[0]->value().type())) {
GatherSelectedRows<platform::CUDADeviceContext, double>(
in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p,
out_var->GetMutable<framework::SelectedRows>());
} else {
PADDLE_ENFORCE(false,
"only support double or float when gahter SelectedRows");
}
#endif
}); });
} else { } else {
std::vector<const LoDTensor *> lod_tensors = std::vector<const LoDTensor *> lod_tensors =

@ -30,6 +30,32 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct CollectiveContext {
std::vector<std::string> endpoints_;
int trainer_id_{0};
std::string String() const {
std::stringstream ss;
ss << "endpoints_:";
for (auto e : endpoints_) {
ss << e << ",";
}
ss << "trainer_id_:" << trainer_id_;
return ss.str();
}
static CollectiveContext *GetInstance() {
std::call_once(init_flag_,
[&]() { context_.reset(new CollectiveContext()); });
return context_.get();
}
private:
static std::once_flag init_flag_;
static std::unique_ptr<CollectiveContext> context_;
};
struct ReduceOpHandle : public OpHandleBase { struct ReduceOpHandle : public OpHandleBase {
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
@ -64,6 +90,19 @@ struct ReduceOpHandle : public OpHandleBase {
protected: protected:
void RunImpl() override; void RunImpl() override;
#if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE
template <typename DevCtx, typename DataType>
void GatherSelectedRows(
const std::vector<const SelectedRows *> &src_selecte_rows_,
const std::vector<platform::Place> &in_places,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
VarHandle *out_var_handle, const platform::Place &out_place,
SelectedRows *dst_selecte_rows);
#endif
void Wait(
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes);
template <typename T> template <typename T>
std::vector<const T *> GetInputValues( std::vector<const T *> GetInputValues(
const std::vector<VarHandle *> &in_var_handles, const std::vector<VarHandle *> &in_var_handles,

@ -97,7 +97,7 @@ void ExecutorThreadWorker::SetDevice() {
static unsigned concurrency_cap = std::thread::hardware_concurrency(); static unsigned concurrency_cap = std::thread::hardware_concurrency();
int thread_id = this->thread_id_; int thread_id = this->thread_id_;
if (thread_id < concurrency_cap) { if (static_cast<unsigned>(thread_id) < concurrency_cap) {
unsigned proc = thread_id; unsigned proc = thread_id;
cpu_set_t mask; cpu_set_t mask;

@ -46,14 +46,16 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE(scope);
std::string type = is_conv3d() ? "conv3d" : "conv2d";
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
gpd.mutable_pattern() gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput() ->AsInput()
->assert_is_op_input("conv2d", "Input"); ->assert_is_op_input(type, "Input");
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_); patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
conv_bias_pattern(conv_input); conv_bias_pattern(conv_input, is_conv3d());
int found_conv_bias_count = 0; int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
@ -109,7 +111,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()})); desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()})); desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()})); desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
desc.SetType("conv2d"); desc.SetType(type);
for (auto& attr : conv->Op()->GetAttrMap()) { for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second); desc.SetAttr(attr.first, attr.second);
@ -135,3 +137,5 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
} // namespace paddle } // namespace paddle
REGISTER_PASS(conv_bias_mkldnn_fuse_pass, REGISTER_PASS(conv_bias_mkldnn_fuse_pass,
paddle::framework::ir::ConvBiasFusePass); paddle::framework::ir::ConvBiasFusePass);
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv3DBiasFusePass);

@ -26,11 +26,19 @@ namespace ir {
class ConvBiasFusePass : public FusePassBase { class ConvBiasFusePass : public FusePassBase {
public: public:
virtual ~ConvBiasFusePass() {} virtual ~ConvBiasFusePass() {}
virtual bool is_conv3d() const { return false; }
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const; std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"conv_bias_mkldnn_fuse"}; const std::string name_scope_{"conv_bias_mkldnn_fuse"};
}; };
/*
* Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp.
*/
class Conv3DBiasFusePass : public ConvBiasFusePass {
public:
bool is_conv3d() const override { return true; }
};
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -177,14 +177,13 @@ class Graph {
return nullptr; return nullptr;
} }
const ProgramDesc &program() const { return program_; }
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
void ResolveHazard( void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes); const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
private: private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());

@ -1030,10 +1030,11 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
} }
PDNode *patterns::ConvBias::operator()( PDNode *patterns::ConvBias::operator()(
paddle::framework::ir::PDNode *conv_input) { paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
std::string type = is_conv3d ? "conv3d" : "conv2d";
// Create Operators // Create Operators
conv_input->assert_is_op_input("conv2d", "Input"); conv_input->assert_is_op_input(type, "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type);
auto *eltiwse_op = auto *eltiwse_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
// Create variables // Create variables
@ -1041,11 +1042,11 @@ PDNode *patterns::ConvBias::operator()(
auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("conv2d", "Filter"); ->assert_is_op_input(type, "Filter");
// intermediate variable, will be removed in the IR after fuse. // intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr()) auto *conv_out_var = pattern->NewNode(conv_out_repr())
->AsIntermediate() ->AsIntermediate()
->assert_is_only_output_of_op("conv2d") ->assert_is_only_output_of_op(type)
->assert_is_op_input("elementwise_add"); ->assert_is_op_input("elementwise_add");
// Bias stored in elementwise_add // Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr()) auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())

@ -623,7 +623,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
struct ConvBias : public PatternBase { struct ConvBias : public PatternBase {
ConvBias(PDPattern* pattern, const std::string& name_scope) ConvBias(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bias") {} : PatternBase(pattern, name_scope, "conv_bias") {}
PDNode* operator()(PDNode* conv_input); PDNode* operator()(PDNode* conv_input, bool is_conv3d = false);
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(conv); PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(eltwise); PATTERN_DECL_NODE(eltwise);

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn_placement_pass.h" #include "paddle/fluid/framework/ir/mkldnn_placement_pass.h"
#include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -21,9 +22,16 @@ namespace ir {
std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl( std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Aplies MKL-DNN placement strategy."; VLOG(3) << "Aplies MKL-DNN placement strategy.";
const auto& op_types_list =
Get<std::unordered_set<std::string>>("mkldnn_enabled_op_types");
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) { if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) {
if (op_types_list.empty()) {
n->Op()->SetAttr("use_mkldnn", true); n->Op()->SetAttr("use_mkldnn", true);
} else if (std::find(op_types_list.begin(), op_types_list.end(),
n->Name()) != op_types_list.end()) {
n->Op()->SetAttr("use_mkldnn", true);
}
} }
} }
return graph; return graph;
@ -33,5 +41,5 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(mkldnn_placement_pass, REGISTER_PASS(mkldnn_placement_pass, paddle::framework::ir::MKLDNNPlacementPass)
paddle::framework::ir::MKLDNNPlacementPass); .RequirePassAttr("mkldnn_enabled_op_types");

@ -0,0 +1,54 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_kernel_type.h"
namespace paddle {
namespace framework {
size_t OpKernelType::Hash::operator()(const OpKernelType& key) const {
int cur_loc = 0;
int place = key.place_.which();
cur_loc += OpKernelType::kPlaceBits;
int data_type = static_cast<int>(key.data_type_) << cur_loc;
cur_loc += OpKernelType::kPrimaryDTypeBits;
int data_layout = static_cast<int>(key.data_layout_) << cur_loc;
cur_loc += OpKernelType::kLayoutBits;
int library_type = static_cast<int>(key.library_type_) << cur_loc;
cur_loc += OpKernelType::kLibBits;
int customized_value = key.customized_type_value_;
PADDLE_ENFORCE(customized_value < (1 << OpKernelType::kCustomizeBits));
customized_value = customized_value << cur_loc;
cur_loc += OpKernelType::kCustomizeBits;
PADDLE_ENFORCE(cur_loc < 64);
std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type +
customized_value);
}
bool OpKernelType::operator==(const OpKernelType& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_ &&
customized_type_value_ == o.customized_type_value_;
}
} // namespace framework
} // namespace paddle

@ -24,54 +24,55 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct OpKernelType { class OpKernelType {
struct Hash { public:
size_t operator()(const OpKernelType& key) const { constexpr static int kDefaultCustomizedTypeValue = 0;
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);
int library_type = static_cast<int>(key.library_type_)
<< (LEFT_SHIFT * 3);
std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type);
}
};
// place, data_type, library_type kinds less than 2^8 // In total should be smaller than 64.
constexpr static int LEFT_SHIFT = 8; constexpr static int kPlaceBits = 4;
constexpr static int kPrimaryDTypeBits = 8;
proto::VarType::Type data_type_; constexpr static int kLayoutBits = 4;
DataLayout data_layout_; constexpr static int kLibBits = 4;
platform::Place place_; constexpr static int kCustomizeBits = 4;
LibraryType library_type_;
OpKernelType(proto::VarType::Type data_type, platform::Place place, OpKernelType(proto::VarType::Type data_type, platform::Place place,
DataLayout data_layout = DataLayout::kAnyLayout, DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain) LibraryType library_type = LibraryType::kPlain,
int customized_type_value = kDefaultCustomizedTypeValue)
: data_type_(data_type), : data_type_(data_type),
data_layout_(data_layout), data_layout_(data_layout),
place_(place), place_(place),
library_type_(library_type) {} library_type_(library_type),
customized_type_value_(customized_type_value) {}
OpKernelType(proto::VarType::Type data_type, OpKernelType(proto::VarType::Type data_type,
const platform::DeviceContext& dev_ctx, const platform::DeviceContext& dev_ctx,
DataLayout data_layout = DataLayout::kAnyLayout, DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain) LibraryType library_type = LibraryType::kPlain,
int customized_type_value = kDefaultCustomizedTypeValue)
: data_type_(data_type), : data_type_(data_type),
data_layout_(data_layout), data_layout_(data_layout),
place_(dev_ctx.GetPlace()), place_(dev_ctx.GetPlace()),
library_type_(library_type) {} library_type_(library_type),
customized_type_value_(customized_type_value) {}
virtual ~OpKernelType() {}
struct Hash {
size_t operator()(const OpKernelType& key) const;
};
size_t hash_key() const { return Hash()(*this); } size_t hash_key() const { return Hash()(*this); }
bool operator==(const OpKernelType& o) const { bool operator==(const OpKernelType& o) const;
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_;
}
bool operator!=(const OpKernelType& o) const { return !(*this == o); } bool operator!=(const OpKernelType& o) const { return !(*this == o); }
proto::VarType::Type data_type_;
DataLayout data_layout_;
platform::Place place_;
LibraryType library_type_;
int customized_type_value_;
}; };
inline std::ostream& operator<<(std::ostream& os, inline std::ostream& operator<<(std::ostream& os,

@ -35,6 +35,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Registrar { class Registrar {
public: public:
// In our design, various kinds of classes, e.g., operators and kernels, // In our design, various kinds of classes, e.g., operators and kernels,
@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor;
template <typename PlaceType, typename T, typename Func> template <typename PlaceType, typename T, typename Func>
inline void RegisterKernelClass(const char* op_type, const char* library_type, inline void RegisterKernelClass(const char* op_type, const char* library_type,
Func func) { int customized_type_value, Func func) {
std::string library(library_type); std::string library(library_type);
std::string data_layout = "ANYLAYOUT"; std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") { if (library == "MKLDNN") {
@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type,
} }
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
StringToDataLayout(data_layout), StringToDataLayout(data_layout),
StringToLibraryType(library_type)); StringToLibraryType(library_type), customized_type_value);
OperatorWithKernel::AllOpKernels()[op_type][key] = func; OperatorWithKernel::AllOpKernels()[op_type][key] = func;
} }
@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE = using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type; typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE; using T = typename KERNEL_TYPE::ELEMENT_TYPE;
RegisterKernelClass<PlaceType, T>( RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) { op_type, library_type, customized_type_value,
[](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx); KERNEL_TYPE().Compute(ctx);
}); });
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value; constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...> OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
template <typename PlaceType, size_t I, typename... KernelType> template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> { struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type, const char* library_type) const {} void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {}
}; };
// User can register many kernel in one place. The data type could be // User can register many kernel in one place. The data type could be
@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
template <typename PlaceType, typename... KernelType> template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar { class OpKernelRegistrar : public Registrar {
public: public:
explicit OpKernelRegistrar(const char* op_type, const char* library_type) { explicit OpKernelRegistrar(const char* op_type, const char* library_type,
int customized_type_value) {
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func; OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx;
template <typename PlaceType, typename... DataTypeAndKernelType> template <typename PlaceType, typename... DataTypeAndKernelType>
class OpKernelRegistrarEx : public Registrar { class OpKernelRegistrarEx : public Registrar {
public: public:
explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) { explicit OpKernelRegistrarEx(const char* op_type, const char* library_type,
int customized_type_value) {
OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...> OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType> template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, true, I, struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
DataTypeAndKernelType...> { DataTypeAndKernelType...> {
void operator()(const char* op_type, const char* library_type) const {} void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {}
}; };
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType> template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
typename std::tuple_element<I, typename std::tuple_element<I,
std::tuple<DataTypeAndKernelType...>>::type; std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type,
RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor()); int customized_type_value) const {
RegisterKernelClass<PlaceType, T>(op_type, library_type,
customized_type_value, Functor());
constexpr auto size = constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value; std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2, OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
DataTypeAndKernelType...> DataTypeAndKernelType...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
// clang-format off
/** /**
* check if MACRO is used in GLOBAL NAMESPACE. * check if MACRO is used in GLOBAL NAMESPACE.
*/ */
@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
*/ */
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \ #define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type, \
place_class, customized_name, \
customized_type_value, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \ __reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \ "REGISTER_OP_KERNEL must be called in " \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \ "global namespace"); \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ static ::paddle::framework::OpKernelRegistrar<place_class, \
#library_type); \ __VA_ARGS__> \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \ __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ #op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \ return 0; \
} }
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( \
op_type, library_type, place_class, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \ #define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \ #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \ #define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \
customized_name, \
customized_type_value, \
...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \ __reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \ "REGISTER_OP_KERNEL_EX must be called in " \
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \ "global namespace"); \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ static ::paddle::framework::OpKernelRegistrarEx<place_class, \
#library_type); \ __VA_ARGS__> \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \ __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ #op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \ return 0; \
} }
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \ #define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \ REGISTER_OP_KERNEL_EX( \
op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__) __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \ #define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL_EX( \
op_type, CPU, ::paddle::platform::CPUPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
/** /**
* Macro to mark what Operator and Kernel * Macro to mark what Operator and Kernel
@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
extern int TouchOpRegistrar_##op_type(); \ extern int TouchOpRegistrar_##op_type(); \
UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type() UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \ #define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
LIBRARY_TYPE, \
customized_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##LIBRARY_TYPE##__, \ __use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##__, \
"USE_OP_DEVICE_KERNEL must be in global namespace"); \ "USE_OP_DEVICE_KERNEL must be in global namespace"); \
extern int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE(); \ extern int \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_ = \ TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE() UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##DEFAULT_TYPE##_ = /* NOLINT */ \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
// TODO(fengjiayi): The following macros // TODO(fengjiayi): The following macros
// seems ugly, do we have better method? // seems ugly, do we have better method?
@ -280,6 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
#define USE_OP(op_type) \ #define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type) USE_OP_KERNEL(op_type)
// clang-format off
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -50,6 +50,8 @@ class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
AddInput("input", "input of test op"); AddInput("input", "input of test op");
AddOutput("output", "output of test op"); AddOutput("output", "output of test op");
AddAttr<float>("scale", "scale of cosine op"); AddAttr<float>("scale", "scale of cosine op");
AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
.SetDefault(0);
AddComment("This is test op"); AddComment("This is test op");
} }
}; };
@ -95,6 +97,8 @@ TEST(OperatorBase, all) {
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static int special_type_value = 1;
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
@ -103,11 +107,14 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.GreaterThan(0.0); .GreaterThan(0.0);
AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
.SetDefault(0);
AddComment("This is test op"); AddComment("This is test op");
} }
}; };
static int cpu_kernel_run_num = 0; static int cpu_kernel_run_num = 0;
static int cpu_kernel2_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
public: public:
@ -117,7 +124,10 @@ class OpWithKernelTest : public OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType( OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override { const ExecutionContext& ctx) const override {
return OpKernelType(proto::VarType::FP32, ctx.GetPlace()); int sub_type = ctx.Attr<int>("kernel_sub_type");
return OpKernelType(proto::VarType::FP32, ctx.GetPlace(),
framework::DataLayout::kAnyLayout,
framework::LibraryType::kPlain, sub_type);
} }
}; };
@ -132,6 +142,17 @@ class CPUKernelTest : public OpKernel<float> {
} }
}; };
template <typename T1, typename T2>
class CPUKernel2Test : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << ctx.op().DebugString() << std::endl;
cpu_kernel2_run_num++;
ASSERT_EQ(ctx.op().Input("x"), "IN1");
ASSERT_EQ(ctx.op().Output("y"), "OUT1");
}
};
class OpKernelTestMultiInputsProtoAndCheckerMaker class OpKernelTestMultiInputsProtoAndCheckerMaker
: public OpProtoAndCheckerMaker { : public OpProtoAndCheckerMaker {
public: public:
@ -142,6 +163,8 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
AddAttr<float>("scale", "scale of cosine op") AddAttr<float>("scale", "scale of cosine op")
.SetDefault(1.0) .SetDefault(1.0)
.GreaterThan(0.0); .GreaterThan(0.0);
AddAttr<int>("kernel_sub_type", "kernels with different implementations.")
.SetDefault(0);
AddComment("This is test op"); AddComment("This is test op");
} }
}; };
@ -189,9 +212,15 @@ class CPUKernalMultiInputsTest : public OpKernel<float> {
REGISTER_OP_WITHOUT_GRADIENT( REGISTER_OP_WITHOUT_GRADIENT(
op_with_kernel, paddle::framework::OpWithKernelTest, op_with_kernel, paddle::framework::OpWithKernelTest,
paddle::framework::OpKernelTestProtoAndCheckerMaker); paddle::framework::OpKernelTestProtoAndCheckerMaker);
REGISTER_OP_CPU_KERNEL(op_with_kernel, REGISTER_OP_CPU_KERNEL(op_with_kernel,
paddle::framework::CPUKernelTest<float, float>); paddle::framework::CPUKernelTest<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
op_with_kernel, CPU, paddle::platform::CPUPlace, MY_SPECIAL_NAME,
paddle::framework::special_type_value,
paddle::framework::CPUKernel2Test<float, float>);
// test with single input // test with single input
TEST(OpKernel, all) { TEST(OpKernel, all) {
paddle::framework::InitDevices(true); paddle::framework::InitDevices(true);
@ -211,7 +240,19 @@ TEST(OpKernel, all) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_place); op->Run(scope, cpu_place);
// kerne_sub_type = 0, hence cpu_kernel is called, cpu_kernel2 is not called.
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 0);
attr = op_desc.mutable_attrs()->Add();
attr->set_name("kernel_sub_type");
attr->set_type(paddle::framework::proto::AttrType::INT);
attr->set_i(1);
auto op2 = paddle::framework::OpRegistry::CreateOp(op_desc);
op2->Run(scope, cpu_place);
// kerne_sub_type = 1, hence cpu_kernel2 is called, cpu_kernel is not called.
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
ASSERT_EQ(paddle::framework::cpu_kernel2_run_num, 1);
} }
REGISTER_OP_WITHOUT_GRADIENT( REGISTER_OP_WITHOUT_GRADIENT(

@ -103,6 +103,7 @@ struct Argument {
// Model specified with program and parameters files. // Model specified with program and parameters files.
DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string); DECL_ARGUMENT_FIELD(model_program_path, ModelProgramPath, std::string);
DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string); DECL_ARGUMENT_FIELD(model_params_path, ModelParamsPath, std::string);
DECL_ARGUMENT_FIELD(model_from_memory, ModelFromMemory, bool);
// The overall graph to work on. // The overall graph to work on.
DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph); DECL_ARGUMENT_UNIQUE_FIELD(main_graph, MainGraph, framework::ir::Graph);
@ -115,6 +116,10 @@ struct Argument {
DECL_ARGUMENT_FIELD(ir_analysis_passes, IrAnalysisPasses, DECL_ARGUMENT_FIELD(ir_analysis_passes, IrAnalysisPasses,
std::vector<std::string>); std::vector<std::string>);
// Pass a set of op types to enable its mkldnn kernel
DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool); DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int); DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool); DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);

@ -63,6 +63,11 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path))); pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
pass_num++; pass_num++;
} }
if (pass_name == "mkldnn_placement_pass") {
pass->Set("mkldnn_enabled_op_types",
new std::unordered_set<std::string>(
argument->mkldnn_enabled_op_types()));
}
if (pass_name == "tensorrt_subgraph_pass") { if (pass_name == "tensorrt_subgraph_pass") {
PADDLE_ENFORCE(argument->tensorrt_node_teller_valid()); PADDLE_ENFORCE(argument->tensorrt_node_teller_valid());

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save