Merge branch 'develop' into develop_4f71a6ee2_conv3d_mkldnn_opt

revert-14786-revert-14782-revert-14398-imperative
Yihua Xu 6 years ago committed by GitHub
commit 65dbc7cca4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -32,6 +32,8 @@ IF(NOT ${WITH_NGRAPH})
return() return()
ENDIF() ENDIF()
INCLUDE(GNUInstallDirs)
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
SET(NGRAPH_PROJECT "extern_ngraph") SET(NGRAPH_PROJECT "extern_ngraph")
@ -40,10 +42,14 @@ SET(NGRAPH_GIT_TAG "f9fd9d4cc318dc59dd4b68448e7fbb5f67a28bd0")
SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph) SET(NGRAPH_SOURCES_DIR ${THIRD_PARTY_PATH}/ngraph)
SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph) SET(NGRAPH_INSTALL_DIR ${THIRD_PARTY_PATH}/install/ngraph)
SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include) SET(NGRAPH_INC_DIR ${NGRAPH_INSTALL_DIR}/include)
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
SET(NGRAPH_SHARED_LIB_NAME libngraph.so.${NGRAPH_VERSION}) SET(NGRAPH_SHARED_LIB_NAME libngraph.so.${NGRAPH_VERSION})
SET(NGRAPH_CPU_LIB_NAME libcpu_backend.so) SET(NGRAPH_CPU_LIB_NAME libcpu_backend.so)
SET(NGRAPH_TBB_LIB_NAME libtbb.so.2) SET(NGRAPH_TBB_LIB_NAME libtbb.so.2)
SET(NGRAPH_GIT_REPO "https://github.com/NervanaSystems/ngraph.git") SET(NGRAPH_GIT_REPO "https://github.com/NervanaSystems/ngraph.git")
SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME})
SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME})
SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
ExternalProject_Add( ExternalProject_Add(
${NGRAPH_PROJECT} ${NGRAPH_PROJECT}
@ -63,18 +69,6 @@ ExternalProject_Add(
CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/lib CMAKE_ARGS -DMKLDNN_LIB_DIR=${MKLDNN_INSTALL_DIR}/lib
) )
if(UNIX AND NOT APPLE)
include(GNUInstallDirs)
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR})
else()
SET(NGRAPH_LIB_DIR ${NGRAPH_INSTALL_DIR}/lib)
endif()
MESSAGE(STATUS "nGraph lib will be installed at: ${NGRAPH_LIB_DIR}")
SET(NGRAPH_SHARED_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_SHARED_LIB_NAME})
SET(NGRAPH_CPU_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_CPU_LIB_NAME})
SET(NGRAPH_TBB_LIB ${NGRAPH_LIB_DIR}/${NGRAPH_TBB_LIB_NAME})
# Workaround for nGraph expecting mklml to be in mkldnn install directory. # Workaround for nGraph expecting mklml to be in mkldnn install directory.
ExternalProject_Add_Step( ExternalProject_Add_Step(
${NGRAPH_PROJECT} ${NGRAPH_PROJECT}

@ -129,6 +129,15 @@ if (WITH_MKLDNN)
) )
endif () endif ()
if (WITH_NGRAPH)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/ngraph")
copy(ngraph_lib
SRCS ${NGRAPH_INC_DIR} ${NGRAPH_LIB_DIR}
DSTS ${dst_dir} ${dst_dir}
DEPS ngraph
)
endif ()
if (NOT WIN32) if (NOT WIN32)
if (NOT MOBILE_INFERENCE AND NOT RPI) if (NOT MOBILE_INFERENCE AND NOT RPI)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappy") set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappy")

@ -166,6 +166,8 @@ function(op_library TARGET)
# Append first implemented MKLDNN activation operator # Append first implemented MKLDNN activation operator
if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") if (${MKLDNN_FILE} STREQUAL "activation_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
elseif(${MKLDNN_FILE} STREQUAL "conv_mkldnn_op")
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN, FP32);\n")
else() else()
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
endif() endif()

@ -182,7 +182,7 @@ paddle.fluid.layers.clip ArgSpec(args=['x', 'min', 'max', 'name'], varargs=None,
paddle.fluid.layers.clip_by_norm ArgSpec(args=['x', 'max_norm', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.clip_by_norm ArgSpec(args=['x', 'max_norm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'ignore_index', 'name'], varargs=None, keywords=None, defaults=(-100, None))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
@ -194,6 +194,8 @@ paddle.fluid.layers.grid_sampler ArgSpec(args=['x', 'grid', 'name'], varargs=Non
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None)) paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act', 'name', 'param_attr', 'bias_attr'], varargs=None, keywords=None, defaults=(None, None, None, None))
paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
@ -299,6 +301,7 @@ paddle.fluid.layers.generate_proposals ArgSpec(args=['scores', 'bbox_deltas', 'i
paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.iou_similarity ArgSpec(args=['x', 'y', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None)) paddle.fluid.layers.box_coder ArgSpec(args=['prior_box', 'prior_box_var', 'target_box', 'code_type', 'box_normalized', 'name'], varargs=None, keywords=None, defaults=('encode_center_size', True, None))
paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.polygon_box_transform ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.yolov3_loss ArgSpec(args=['x', 'gtbox', 'gtlabel', 'anchors', 'class_num', 'ignore_thresh', 'loss_weight_xy', 'loss_weight_wh', 'loss_weight_conf_target', 'loss_weight_conf_notarget', 'loss_weight_class', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, None, None))
paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None)) paddle.fluid.layers.accuracy ArgSpec(args=['input', 'label', 'k', 'correct', 'total'], varargs=None, keywords=None, defaults=(1, None, None))
paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1)) paddle.fluid.layers.auc ArgSpec(args=['input', 'label', 'curve', 'num_thresholds', 'topk', 'slide_steps'], varargs=None, keywords=None, defaults=('ROC', 4095, 1, 1))
paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.layers.exponential_decay ArgSpec(args=['learning_rate', 'decay_steps', 'decay_rate', 'staircase'], varargs=None, keywords=None, defaults=(False,))
@ -419,3 +422,17 @@ paddle.fluid.Scope.drop_kids drop_kids(self: paddle.fluid.core.Scope) -> None
paddle.fluid.Scope.find_var find_var(self: paddle.fluid.core.Scope, arg0: unicode) -> paddle.fluid.core.Variable paddle.fluid.Scope.find_var find_var(self: paddle.fluid.core.Scope, arg0: unicode) -> paddle.fluid.core.Variable
paddle.fluid.Scope.new_scope new_scope(self: paddle.fluid.core.Scope) -> paddle.fluid.core.Scope paddle.fluid.Scope.new_scope new_scope(self: paddle.fluid.core.Scope) -> paddle.fluid.core.Scope
paddle.fluid.Scope.var var(self: paddle.fluid.core.Scope, arg0: unicode) -> paddle.fluid.core.Variable paddle.fluid.Scope.var var(self: paddle.fluid.core.Scope, arg0: unicode) -> paddle.fluid.core.Variable
paddle.reader.map_readers ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None)
paddle.reader.buffered ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None)
paddle.reader.compose ArgSpec(args=[], varargs='readers', keywords='kwargs', defaults=None)
paddle.reader.chain ArgSpec(args=[], varargs='readers', keywords=None, defaults=None)
paddle.reader.shuffle ArgSpec(args=['reader', 'buf_size'], varargs=None, keywords=None, defaults=None)
paddle.reader.firstn ArgSpec(args=['reader', 'n'], varargs=None, keywords=None, defaults=None)
paddle.reader.xmap_readers ArgSpec(args=['mapper', 'reader', 'process_num', 'buffer_size', 'order'], varargs=None, keywords=None, defaults=(False,))
paddle.reader.PipeReader.__init__ ArgSpec(args=['self', 'command', 'bufsize', 'file_type'], varargs=None, keywords=None, defaults=(8192, 'plain'))
paddle.reader.PipeReader.get_line ArgSpec(args=['self', 'cut_lines', 'line_break'], varargs=None, keywords=None, defaults=(True, '\n'))
paddle.reader.multiprocess_reader ArgSpec(args=['readers', 'use_pipe', 'queue_size'], varargs=None, keywords=None, defaults=(True, 1000))
paddle.reader.Fake.__init__ ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
paddle.reader.creator.np_array ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
paddle.reader.creator.text_file ArgSpec(args=['path'], varargs=None, keywords=None, defaults=None)
paddle.reader.creator.recordio ArgSpec(args=['paths', 'buf_size'], varargs=None, keywords=None, defaults=(100,))

@ -118,8 +118,9 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)
cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context) cc_library(transfer_scope_cache SRCS transfer_scope_cache.cc DEPS scope framework_proto device_context)
cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler transfer_scope_cache) shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
@ -127,8 +128,9 @@ cc_library(version SRCS version.cc)
cc_test(version_test SRCS version_test.cc DEPS version) cc_test(version_test SRCS version_test.cc DEPS version)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto)
if(NOT WIN32) if(NOT WIN32)
cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler) shape_inference data_transform lod_tensor profiler)
endif(NOT WIN32) endif(NOT WIN32)
@ -190,7 +192,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto op_kernel_type)
cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc) cc_test(cow_ptr_tests SRCS details/cow_ptr_test.cc)
cc_test(tuple_test SRCS tuple_test.cc ) cc_test(tuple_test SRCS tuple_test.cc )

@ -48,7 +48,14 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
void AllReduceOpHandle::RunImpl() { void AllReduceOpHandle::RunImpl() {
platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second); platform::RecordEvent record_event(Name(), dev_ctxes_.cbegin()->second);
// FIXME(typhoonzero): If scope0(global scope) have NCCL_ID_VAR,
// this is a distributed or inter-process call, find a better way.
#ifdef PADDLE_WITH_CUDA
if (NoDummyInputSize() == 1 &&
local_scopes_[0]->FindLocalVar(NCCL_ID_VARNAME) == nullptr) {
#else
if (NoDummyInputSize() == 1) { if (NoDummyInputSize() == 1) {
#endif
return; // No need to all reduce when GPU count = 1; return; // No need to all reduce when GPU count = 1;
} else { } else {
// Wait input done // Wait input done

@ -62,6 +62,8 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
auto multi_devices_pass = AppendPass("multi_devices_pass"); auto multi_devices_pass = AppendPass("multi_devices_pass");
multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy", multi_devices_pass->SetNotOwned<const BuildStrategy>("strategy",
&strategy_); &strategy_);
multi_devices_pass->Set<int>("num_trainers",
new int(strategy_.num_trainers_));
// Add a graph print pass to record a graph with device info. // Add a graph print pass to record a graph with device info.
if (!strategy_.debug_graphviz_path_.empty()) { if (!strategy_.debug_graphviz_path_.empty()) {

@ -133,6 +133,7 @@ static const char kPlaces[] = "places";
static const char kParams[] = "params"; static const char kParams[] = "params";
static const char kLocalScopes[] = "local_scopes"; static const char kLocalScopes[] = "local_scopes";
static const char kStrategy[] = "strategy"; static const char kStrategy[] = "strategy";
static const char kNumTrainers[] = "num_trainers";
void MultiDevSSAGraphBuilder::Init() const { void MultiDevSSAGraphBuilder::Init() const {
all_vars_.clear(); all_vars_.clear();
@ -299,6 +300,8 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
auto nodes = graph->ReleaseNodes(); auto nodes = graph->ReleaseNodes();
ir::Graph &result = *graph; ir::Graph &result = *graph;
int num_trainers = Get<int>(kNumTrainers);
for (auto &node : nodes) { for (auto &node : nodes) {
if (node->IsVar() && node->Var()) { if (node->IsVar() && node->Var()) {
all_vars_.emplace(node->Name(), node->Var()); all_vars_.emplace(node->Name(), node->Var());
@ -383,7 +386,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateComputationalOps(&result, node, places_.size()); CreateComputationalOps(&result, node, places_.size());
} }
if (!is_forwarding && places_.size() > 1) { if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. // broadcast, and each gradient is only broadcast once.
if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr( if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
@ -895,4 +898,5 @@ REGISTER_PASS(multi_devices_pass,
.RequirePassAttr(paddle::framework::details::kPlaces) .RequirePassAttr(paddle::framework::details::kPlaces)
.RequirePassAttr(paddle::framework::details::kParams) .RequirePassAttr(paddle::framework::details::kParams)
.RequirePassAttr(paddle::framework::details::kLocalScopes) .RequirePassAttr(paddle::framework::details::kLocalScopes)
.RequirePassAttr(paddle::framework::details::kStrategy); .RequirePassAttr(paddle::framework::details::kStrategy)
.RequirePassAttr(paddle::framework::details::kNumTrainers);

@ -32,9 +32,7 @@ enum OpInfoFillType {
kOpProtoAndCheckerMaker = 1, kOpProtoAndCheckerMaker = 1,
kGradOpDescMaker = 2, kGradOpDescMaker = 2,
kVarTypeInference = 3, kVarTypeInference = 3,
kShapeInference = 4, kShapeInference = 4
kEstimateFlops = 5,
kUnknown = -1
}; };
template <typename T> template <typename T>
@ -50,10 +48,8 @@ struct OpInfoFillTypeID {
? kVarTypeInference ? kVarTypeInference
: (std::is_base_of<InferShapeBase, T>::value : (std::is_base_of<InferShapeBase, T>::value
? kShapeInference ? kShapeInference
: (std::is_base_of<EstimateFlopsBase, : static_cast<OpInfoFillType>(
T>::value -1)))));
? kEstimateFlops
: kUnknown)))));
} }
}; };
@ -143,16 +139,6 @@ struct OpInfoFiller<T, kShapeInference> {
} }
}; };
template <typename T>
struct OpInfoFiller<T, kEstimateFlops> {
void operator()(const char* op_tpe, OpInfo* info) const {
info->estimate_flops_ = [](InferShapeContext* ctx) {
T estimate_flops;
return estimate_flops(ctx);
};
}
};
} // namespace details } // namespace details
} // namespace framework } // namespace framework

@ -177,14 +177,13 @@ class Graph {
return nullptr; return nullptr;
} }
const ProgramDesc &program() const { return program_; }
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
void ResolveHazard( void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes); const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
private: private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
// This method takes ownership of `node`. // This method takes ownership of `node`.
ir::Node *AddNode(ir::Node *node) { ir::Node *AddNode(ir::Node *node) {
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end()); PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());

@ -38,7 +38,7 @@ std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp()) { if (n->IsOp()) {
auto* op = n->Op(); auto* op = n->Op();
if (op->HasAttr("is_test")) { if (n->RuntimeHasAttr("is_test")) {
op->SetAttr("is_test", true); op->SetAttr("is_test", true);
} else if (std::find(begin(op_list), end(op_list), op->Type()) != } else if (std::find(begin(op_list), end(op_list), op->Type()) !=
end(op_list)) { end(op_list)) {

@ -104,9 +104,9 @@ TEST(IsTestPass, basic) {
auto* op = node->Op(); auto* op = node->Op();
auto op_name = boost::get<std::string>(op->GetAttr("name")); auto op_name = boost::get<std::string>(op->GetAttr("name"));
if (op_name == "conv3") { if (op_name == "conv3") {
ASSERT_FALSE(op->HasAttr("is_test")); ASSERT_FALSE(node->RuntimeHasAttr("is_test"));
} else { } else {
ASSERT_TRUE(op->HasAttr("is_test")); ASSERT_TRUE(node->RuntimeHasAttr("is_test"));
EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test"))); EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test")));
} }
} }

@ -22,7 +22,7 @@ std::unique_ptr<ir::Graph> MKLDNNPlacementPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
VLOG(3) << "Aplies MKL-DNN placement strategy."; VLOG(3) << "Aplies MKL-DNN placement strategy.";
for (const Node* n : graph->Nodes()) { for (const Node* n : graph->Nodes()) {
if (n->IsOp() && n->Op()->HasAttr("use_mkldnn")) { if (n->IsOp() && n->RuntimeHasAttr("use_mkldnn")) {
n->Op()->SetAttr("use_mkldnn", true); n->Op()->SetAttr("use_mkldnn", true);
} }
} }

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_info.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
@ -24,10 +25,33 @@ constexpr char Node::kControlDepVarName[];
const char Node::kControlDepVarName[] = "__control_var"; const char Node::kControlDepVarName[] = "__control_var";
#endif #endif
std::unique_ptr<Node> CreateNodeForTest(const std::string& name, std::unique_ptr<Node> CreateNodeForTest(const std::string &name,
Node::Type type) { Node::Type type) {
return std::unique_ptr<Node>(new Node(name, type)); return std::unique_ptr<Node>(new Node(name, type));
} }
bool Node::RuntimeHasAttr(const std::string &name) const {
if (Op()->HasAttr(name)) {
return true;
} else {
auto &op_info = OpInfoMap::Instance();
auto op_type = Op()->Type();
if (op_info.Has(op_type)) {
auto op_info_ptr = op_info.Get(op_type);
if (op_info_ptr.HasOpProtoAndChecker()) {
const proto::OpProto &proto = op_info_ptr.Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return true;
}
}
}
}
}
return false;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -108,6 +108,18 @@ class Node {
Name().find(ir::Node::kControlDepVarName) != std::string::npos; Name().find(ir::Node::kControlDepVarName) != std::string::npos;
} }
// RuntimeHasAttr is different with HasAttr now.
// 1. For Op()->HasAttr(), it judges whether a stored program_desc_ has attr,
// thus, if stored program_desc_ are old which don't have an attr, a new
// library which adds the attr already will fail on this function.
// Details:
// https://github.com/PaddlePaddle/Paddle/pull/14608#issuecomment-442309087
// 2. For Op()->RuntimeHasAttr, it judges the attr in runtime to avoid above
// problem.
// TODO(luotao): Maybe we should enhance HasAttr later, instead of adding
// RuntimeHasAttr.
bool RuntimeHasAttr(const std::string& name) const;
std::vector<Node*> inputs; std::vector<Node*> inputs;
std::vector<Node*> outputs; std::vector<Node*> outputs;

@ -15,23 +15,105 @@ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
#include <vector>
#include "paddle/fluid/framework/ngraph_bridge.h" #include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "ngraph/ngraph.hpp" #include "ngraph/ngraph.hpp"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static std::shared_ptr<ngraph::Node> GetNode(
const std::shared_ptr<OperatorBase>& op, const std::string prm,
const VariableNameMap& var_map,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = var_map.at(prm);
PADDLE_ENFORCE_EQ(var_names.size(), 1,
"op %s prm %s expects one associated var", op->Type(), prm);
if (ngb_node_map->find(var_names[0]) != ngb_node_map->end()) {
return (*ngb_node_map)[var_names[0]];
} else {
return nullptr;
}
}
static std::shared_ptr<ngraph::Node> GetInputNode(
const std::shared_ptr<OperatorBase>& op, const std::string prm,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, prm, op->Inputs(), ngb_node_map);
}
static std::shared_ptr<ngraph::Node> GetOutputNode(
const std::shared_ptr<OperatorBase>& op, const std::string prm,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
return GetNode(op, prm, op->Outputs(), ngb_node_map);
}
static void SetOutputNode(
const std::shared_ptr<OperatorBase>& op, const std::string prm,
std::shared_ptr<ngraph::Node> node,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto& var_names = op->Outputs().at(prm);
if (var_names.size() == 1) {
(*ngb_node_map)[var_names[0]] = node;
} else if (var_names.size() == 0) {
(*ngb_node_map)[""] = node;
} else {
PADDLE_THROW("prm %s has more than 1 var_names.", prm);
}
}
static bool HasOutput(const std::shared_ptr<OperatorBase>& op,
const std::string prm) {
auto& outputs = op->Outputs();
if (outputs.find(prm) == outputs.end()) return false;
return outputs.at(prm).size() > 0;
}
template <typename T>
static void BuildBinaryNode(
const std::shared_ptr<OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto x = GetInputNode(op, "X", ngb_node_map);
auto y = GetInputNode(op, "Y", ngb_node_map);
auto out = std::make_shared<T>(x, y);
SetOutputNode(op, "Out", out, ngb_node_map);
}
template <typename T>
static void BuildUnaryNode(
const std::shared_ptr<OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto input = GetInputNode(op, "X", ngb_node_map);
auto out = std::make_shared<T>(input);
SetOutputNode(op, "Out", out, ngb_node_map);
}
std::map<std::string, std::map<std::string,
std::function<void(const std::shared_ptr<OperatorBase>&, std::function<void(const std::shared_ptr<OperatorBase>&,
std::shared_ptr<std::unordered_map< std::shared_ptr<std::unordered_map<
std::string, std::shared_ptr<ngraph::Node>>>)>> std::string, std::shared_ptr<ngraph::Node>>>)>>
NgraphBridge::NG_NODE_MAP = {}; NgraphBridge::NG_NODE_MAP = {{"relu", BuildUnaryNode<ngraph::op::Relu>},
{"tanh", BuildUnaryNode<ngraph::op::Tanh>}};
void NgraphBridge::build_graph(const std::shared_ptr<OperatorBase>& op) { void NgraphBridge::BuildNgNode(const std::shared_ptr<OperatorBase>& op) {
auto& op_type = op->Type(); auto& op_type = op->Type();
NG_NODE_MAP[op_type](op, ngb_node_map); NG_NODE_MAP[op_type](op, ngb_node_map_);
} }
} // namespace framework } // namespace framework

@ -20,16 +20,14 @@ limitations under the License. */
#include <map> #include <map>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/operator.h" #include "ngraph/node.hpp"
#include "paddle/fluid/platform/enforce.h"
#include "ngraph/ngraph.hpp"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase;
class NgraphBridge { class NgraphBridge {
public: public:
static std::map< static std::map<
@ -43,14 +41,14 @@ class NgraphBridge {
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
var_node_map) var_node_map)
: ngb_node_map(var_node_map) {} : ngb_node_map_(var_node_map) {}
void build_graph(const std::shared_ptr<OperatorBase>& op); void BuildNgNode(const std::shared_ptr<OperatorBase>& op);
private: private:
std::shared_ptr< std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>> std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map; ngb_node_map_;
}; };
} // namespace framework } // namespace framework

File diff suppressed because it is too large Load Diff

@ -17,24 +17,19 @@ limitations under the License. */
#ifdef PADDLE_WITH_NGRAPH #ifdef PADDLE_WITH_NGRAPH
#include <algorithm> #include <algorithm>
#include <atomic>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/attribute.h" #include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/ngraph_bridge.h"
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
#include "ngraph/ngraph.hpp" #include "ngraph/type/element_type.hpp"
namespace paddle { namespace paddle {
namespace framework { namespace framework {

@ -31,12 +31,6 @@ class InferShapeBase {
virtual void operator()(InferShapeContext*) const = 0; virtual void operator()(InferShapeContext*) const = 0;
}; };
class EstimateFlopsBase {
public:
virtual ~EstimateFlopsBase() = default;
virtual size_t operator()(InferShapeContext*) const = 0;
};
struct OpInfo { struct OpInfo {
OpCreator creator_; OpCreator creator_;
GradOpMakerFN grad_op_maker_; GradOpMakerFN grad_op_maker_;
@ -44,7 +38,6 @@ struct OpInfo {
OpAttrChecker* checker_{nullptr}; OpAttrChecker* checker_{nullptr};
InferVarTypeFN infer_var_type_; InferVarTypeFN infer_var_type_;
InferShapeFN infer_shape_; InferShapeFN infer_shape_;
EstimateFlopsFN estimate_flops_;
bool HasOpProtoAndChecker() const { bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr; return proto_ != nullptr && checker_ != nullptr;

@ -0,0 +1,54 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_kernel_type.h"
namespace paddle {
namespace framework {
size_t OpKernelType::Hash::operator()(const OpKernelType& key) const {
int cur_loc = 0;
int place = key.place_.which();
cur_loc += OpKernelType::kPlaceBits;
int data_type = static_cast<int>(key.data_type_) << cur_loc;
cur_loc += OpKernelType::kPrimaryDTypeBits;
int data_layout = static_cast<int>(key.data_layout_) << cur_loc;
cur_loc += OpKernelType::kLayoutBits;
int library_type = static_cast<int>(key.library_type_) << cur_loc;
cur_loc += OpKernelType::kLibBits;
int customized_value = key.customized_type_value_;
PADDLE_ENFORCE(customized_value < (1 << OpKernelType::kCustomizeBits));
customized_value = customized_value << cur_loc;
cur_loc += OpKernelType::kCustomizeBits;
PADDLE_ENFORCE(cur_loc < 64);
std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type +
customized_value);
}
bool OpKernelType::operator==(const OpKernelType& o) const {
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_ &&
customized_type_value_ == o.customized_type_value_;
}
} // namespace framework
} // namespace paddle

@ -24,54 +24,55 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
struct OpKernelType { class OpKernelType {
struct Hash { public:
size_t operator()(const OpKernelType& key) const { constexpr static int kDefaultCustomizedTypeValue = 0;
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);
int library_type = static_cast<int>(key.library_type_)
<< (LEFT_SHIFT * 3);
std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type);
}
};
// place, data_type, library_type kinds less than 2^8 // In total should be smaller than 64.
constexpr static int LEFT_SHIFT = 8; constexpr static int kPlaceBits = 4;
constexpr static int kPrimaryDTypeBits = 8;
proto::VarType::Type data_type_; constexpr static int kLayoutBits = 4;
DataLayout data_layout_; constexpr static int kLibBits = 4;
platform::Place place_; constexpr static int kCustomizeBits = 4;
LibraryType library_type_;
OpKernelType(proto::VarType::Type data_type, platform::Place place, OpKernelType(proto::VarType::Type data_type, platform::Place place,
DataLayout data_layout = DataLayout::kAnyLayout, DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain) LibraryType library_type = LibraryType::kPlain,
int customized_type_value = kDefaultCustomizedTypeValue)
: data_type_(data_type), : data_type_(data_type),
data_layout_(data_layout), data_layout_(data_layout),
place_(place), place_(place),
library_type_(library_type) {} library_type_(library_type),
customized_type_value_(customized_type_value) {}
OpKernelType(proto::VarType::Type data_type, OpKernelType(proto::VarType::Type data_type,
const platform::DeviceContext& dev_ctx, const platform::DeviceContext& dev_ctx,
DataLayout data_layout = DataLayout::kAnyLayout, DataLayout data_layout = DataLayout::kAnyLayout,
LibraryType library_type = LibraryType::kPlain) LibraryType library_type = LibraryType::kPlain,
int customized_type_value = kDefaultCustomizedTypeValue)
: data_type_(data_type), : data_type_(data_type),
data_layout_(data_layout), data_layout_(data_layout),
place_(dev_ctx.GetPlace()), place_(dev_ctx.GetPlace()),
library_type_(library_type) {} library_type_(library_type),
customized_type_value_(customized_type_value) {}
virtual ~OpKernelType() {}
struct Hash {
size_t operator()(const OpKernelType& key) const;
};
size_t hash_key() const { return Hash()(*this); } size_t hash_key() const { return Hash()(*this); }
bool operator==(const OpKernelType& o) const { bool operator==(const OpKernelType& o) const;
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_ && data_layout_ == o.data_layout_ &&
library_type_ == o.library_type_;
}
bool operator!=(const OpKernelType& o) const { return !(*this == o); } bool operator!=(const OpKernelType& o) const { return !(*this == o); }
proto::VarType::Type data_type_;
DataLayout data_layout_;
platform::Place place_;
LibraryType library_type_;
int customized_type_value_;
}; };
inline std::ostream& operator<<(std::ostream& os, inline std::ostream& operator<<(std::ostream& os,

@ -35,6 +35,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Registrar { class Registrar {
public: public:
// In our design, various kinds of classes, e.g., operators and kernels, // In our design, various kinds of classes, e.g., operators and kernels,
@ -78,7 +79,7 @@ struct OpKernelRegistrarFunctor;
template <typename PlaceType, typename T, typename Func> template <typename PlaceType, typename T, typename Func>
inline void RegisterKernelClass(const char* op_type, const char* library_type, inline void RegisterKernelClass(const char* op_type, const char* library_type,
Func func) { int customized_type_value, Func func) {
std::string library(library_type); std::string library(library_type);
std::string data_layout = "ANYLAYOUT"; std::string data_layout = "ANYLAYOUT";
if (library == "MKLDNN") { if (library == "MKLDNN") {
@ -86,7 +87,7 @@ inline void RegisterKernelClass(const char* op_type, const char* library_type,
} }
OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(),
StringToDataLayout(data_layout), StringToDataLayout(data_layout),
StringToLibraryType(library_type)); StringToLibraryType(library_type), customized_type_value);
OperatorWithKernel::AllOpKernels()[op_type][key] = func; OperatorWithKernel::AllOpKernels()[op_type][key] = func;
} }
@ -95,22 +96,26 @@ struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE = using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type; typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE; using T = typename KERNEL_TYPE::ELEMENT_TYPE;
RegisterKernelClass<PlaceType, T>( RegisterKernelClass<PlaceType, T>(
op_type, library_type, [](const framework::ExecutionContext& ctx) { op_type, library_type, customized_type_value,
[](const framework::ExecutionContext& ctx) {
KERNEL_TYPE().Compute(ctx); KERNEL_TYPE().Compute(ctx);
}); });
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value; constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...> OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
template <typename PlaceType, size_t I, typename... KernelType> template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> { struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type, const char* library_type) const {} void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {}
}; };
// User can register many kernel in one place. The data type could be // User can register many kernel in one place. The data type could be
@ -118,9 +123,10 @@ struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
template <typename PlaceType, typename... KernelType> template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar { class OpKernelRegistrar : public Registrar {
public: public:
explicit OpKernelRegistrar(const char* op_type, const char* library_type) { explicit OpKernelRegistrar(const char* op_type, const char* library_type,
int customized_type_value) {
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func; OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
@ -130,17 +136,19 @@ struct OpKernelRegistrarFunctorEx;
template <typename PlaceType, typename... DataTypeAndKernelType> template <typename PlaceType, typename... DataTypeAndKernelType>
class OpKernelRegistrarEx : public Registrar { class OpKernelRegistrarEx : public Registrar {
public: public:
explicit OpKernelRegistrarEx(const char* op_type, const char* library_type) { explicit OpKernelRegistrarEx(const char* op_type, const char* library_type,
int customized_type_value) {
OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...> OpKernelRegistrarFunctorEx<PlaceType, false, 0, DataTypeAndKernelType...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType> template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
struct OpKernelRegistrarFunctorEx<PlaceType, true, I, struct OpKernelRegistrarFunctorEx<PlaceType, true, I,
DataTypeAndKernelType...> { DataTypeAndKernelType...> {
void operator()(const char* op_type, const char* library_type) const {} void operator()(const char* op_type, const char* library_type,
int customized_type_value) const {}
}; };
template <typename PlaceType, size_t I, typename... DataTypeAndKernelType> template <typename PlaceType, size_t I, typename... DataTypeAndKernelType>
@ -153,18 +161,21 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
typename std::tuple_element<I, typename std::tuple_element<I,
std::tuple<DataTypeAndKernelType...>>::type; std::tuple<DataTypeAndKernelType...>>::type;
void operator()(const char* op_type, const char* library_type) const { void operator()(const char* op_type, const char* library_type,
RegisterKernelClass<PlaceType, T>(op_type, library_type, Functor()); int customized_type_value) const {
RegisterKernelClass<PlaceType, T>(op_type, library_type,
customized_type_value, Functor());
constexpr auto size = constexpr auto size =
std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value; std::tuple_size<std::tuple<DataTypeAndKernelType...>>::value;
OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2, OpKernelRegistrarFunctorEx<PlaceType, I + 2 >= size, I + 2,
DataTypeAndKernelType...> DataTypeAndKernelType...>
func; func;
func(op_type, library_type); func(op_type, library_type, customized_type_value);
} }
}; };
// clang-format off
/** /**
* check if MACRO is used in GLOBAL NAMESPACE. * check if MACRO is used in GLOBAL NAMESPACE.
*/ */
@ -199,42 +210,64 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
/** /**
* Macro to register OperatorKernel. * Macro to register OperatorKernel.
*/ */
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \ #define REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(op_type, library_type, \
place_class, customized_name, \
customized_type_value, ...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \ __reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
"REGISTER_OP_KERNEL must be called in global namespace"); \ "REGISTER_OP_KERNEL must be called in " \
static ::paddle::framework::OpKernelRegistrar<place_class, __VA_ARGS__> \ "global namespace"); \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ static ::paddle::framework::OpKernelRegistrar<place_class, \
#library_type); \ __VA_ARGS__> \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \ __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ #op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \ return 0; \
} }
#define REGISTER_OP_KERNEL(op_type, library_type, place_class, ...) \
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE( \
op_type, library_type, place_class, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \ #define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \ #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
#define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, ...) \ #define REGISTER_OP_KERNEL_EX(op_type, library_type, place_class, \
customized_name, \
customized_type_value, \
...) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##op_type##_##library_type##__, \ __reg_op_kernel_##op_type##_##library_type##_##customized_name##__, \
"REGISTER_OP_KERNEL_EX must be called in global namespace"); \ "REGISTER_OP_KERNEL_EX must be called in " \
static ::paddle::framework::OpKernelRegistrarEx<place_class, __VA_ARGS__> \ "global namespace"); \
__op_kernel_registrar_##op_type##_##library_type##__(#op_type, \ static ::paddle::framework::OpKernelRegistrarEx<place_class, \
#library_type); \ __VA_ARGS__> \
int TouchOpKernelRegistrar_##op_type##_##library_type() { \ __op_kernel_registrar_##op_type##_##library_type##_##customized_name##__(\
__op_kernel_registrar_##op_type##_##library_type##__.Touch(); \ #op_type, #library_type, customized_type_value); \
int TouchOpKernelRegistrar_##op_type##_##library_type##_##customized_name() {\
__op_kernel_registrar_##op_type##_##library_type##_##customized_name##__ \
.Touch(); \
return 0; \ return 0; \
} }
#define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \ #define REGISTER_OP_CUDA_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CUDA, ::paddle::platform::CUDAPlace, \ REGISTER_OP_KERNEL_EX( \
op_type, CUDA, ::paddle::platform::CUDAPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__) __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \ #define REGISTER_OP_CPU_KERNEL_FUNCTOR(op_type, ...) \
REGISTER_OP_KERNEL_EX(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL_EX( \
op_type, CPU, ::paddle::platform::CPUPlace, DEFAULT_TYPE, \
::paddle::framework::OpKernelType::kDefaultCustomizedTypeValue, \
__VA_ARGS__)
/** /**
* Macro to mark what Operator and Kernel * Macro to mark what Operator and Kernel
@ -248,13 +281,19 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
extern int TouchOpRegistrar_##op_type(); \ extern int TouchOpRegistrar_##op_type(); \
UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type() UNUSED static int use_op_itself_##op_type##_ = TouchOpRegistrar_##op_type()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \ #define USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, \
LIBRARY_TYPE, \
customized_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \ STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_op_kernel_##op_type##_##LIBRARY_TYPE##__, \ __use_op_kernel_##op_type##_##LIBRARY_TYPE##_##customized_name##__, \
"USE_OP_DEVICE_KERNEL must be in global namespace"); \ "USE_OP_DEVICE_KERNEL must be in global namespace"); \
extern int TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE(); \ extern int \
UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_ = \ TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name(); \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE() UNUSED static int use_op_kernel_##op_type##_##LIBRARY_TYPE##_##DEFAULT_TYPE##_ = /* NOLINT */ \
TouchOpKernelRegistrar_##op_type##_##LIBRARY_TYPE##_##customized_name()
#define USE_OP_DEVICE_KERNEL(op_type, LIBRARY_TYPE) \
USE_OP_DEVICE_KERNEL_WITH_CUSTOM_TYPE(op_type, LIBRARY_TYPE, DEFAULT_TYPE)
// TODO(fengjiayi): The following macros // TODO(fengjiayi): The following macros
// seems ugly, do we have better method? // seems ugly, do we have better method?
@ -280,6 +319,7 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
#define USE_OP(op_type) \ #define USE_OP(op_type) \
USE_OP_ITSELF(op_type); \ USE_OP_ITSELF(op_type); \
USE_OP_KERNEL(op_type) USE_OP_KERNEL(op_type)
// clang-format off
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -695,6 +695,12 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN", name); "Tensor %s contains NAN", name);
} }
void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
}
void OperatorWithKernel::RunImpl(const Scope& scope, void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);

@ -128,6 +128,8 @@ class OperatorBase {
virtual std::vector<std::string> OutputVars(bool has_intermediate) const; virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
virtual void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {}
protected: protected:
std::string type_; std::string type_;
@ -348,6 +350,9 @@ class OperatorWithKernel : public OperatorBase {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
} }
void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const override;
protected: protected:
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetKernelTypeForVar( virtual OpKernelType GetKernelTypeForVar(

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save