Merge develop

test=develop
revert-16555-model_data_cryption_link_all_lib
sneaxiy 6 years ago
commit 16f0994728

@ -64,6 +64,7 @@ option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(WITH_PSLIB "Compile with pslib support" OFF) option(WITH_PSLIB "Compile with pslib support" OFF)
option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF) option(REPLACE_ENFORCE_GLOG "Replace PADDLE_ENFORCE with glog/CHECK for better debug." OFF)
# TODO(Superjomn) Remove WITH_ANAKIN option if not needed later.
option(WITH_ANAKIN "Compile with Anakin library" OFF) option(WITH_ANAKIN "Compile with Anakin library" OFF)
option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF) option(ANAKIN_BUILD_FAT_BIN "Build anakin cuda fat-bin lib for all device plantform, ignored when WITH_ANAKIN=OFF" OFF)
option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON) option(ANAKIN_BUILD_CROSS_PLANTFORM "Build anakin lib for any nvidia device plantform. ignored when WITH_ANAKIN=OFF" ON)
@ -190,6 +191,7 @@ include(configure) # add paddle env configuration
if(WITH_GPU) if(WITH_GPU)
include(cuda) include(cuda)
include(tensorrt) include(tensorrt)
include(anakin_subgraph)
endif() endif()
if(WITH_MKL OR WITH_MKLML) if(WITH_MKL OR WITH_MKLML)
include(external/anakin) include(external/anakin)

@ -156,7 +156,7 @@ python \
This will enable VLOG messages generated by `buddy_allocator.{h,cc}` and in the verbose range of 0 to 3, so you will see above example VLOG message, which is in level 3. This suggests that we output overall messages in lower verbose levels, so they display with higher probability. When coding C++, please follow the verbose level convention as follows: This will enable VLOG messages generated by `buddy_allocator.{h,cc}` and in the verbose range of 0 to 3, so you will see above example VLOG message, which is in level 3. This suggests that we output overall messages in lower verbose levels, so they display with higher probability. When coding C++, please follow the verbose level convention as follows:
- verbose level 1: [framework](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/framework) - verbose level 1: [framework](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/framework)
- verbose level 3: [operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators) - verbose level 3: [operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/operators)
- verbose level 5: [memory](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/memory), [platform](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/platform) - verbose level 5: [memory](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/memory), [platform](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/platform)
- verbose level 7: [math](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/legacy/math) - verbose level 7: [math](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/fluid/operators/math/)

@ -0,0 +1,32 @@
# Locate a pre-built Anakin installation for the GPU build.
#
# Inputs:
#   WITH_GPU    - nothing is done unless this is enabled.
#   WITH_DSO    - Anakin is only reported as found when this is enabled.
#   ANAKIN_ROOT - cache path (default "/usr"); the ANAKIN_ROOT environment
#                 variable is searched as well.
# Outputs:
#   ANAKIN_INCLUDE_DIR - directory containing anakin_config.h (cache entry).
#   ANAKIN_LIBRARY     - path of the Anakin shared library (cache entry).
#   ANAKIN_FOUND       - ON only when header and library were both found and
#                        WITH_DSO is enabled, OFF in every other case.
if(NOT WITH_GPU)
  return()
endif()

set(ANAKIN_ROOT "/usr" CACHE PATH "ANAKIN ROOT")

# Only the explicitly listed locations are searched (NO_DEFAULT_PATH).
find_path(ANAKIN_INCLUDE_DIR anakin_config.h
  PATHS ${ANAKIN_ROOT} ${ANAKIN_ROOT}/include
        $ENV{ANAKIN_ROOT} $ENV{ANAKIN_ROOT}/include
  NO_DEFAULT_PATH
)

find_library(ANAKIN_LIBRARY NAMES libanakin_saber_common.so libanakin.so
  PATHS ${ANAKIN_ROOT}
        $ENV{ANAKIN_ROOT} $ENV{ANAKIN_ROOT}/lib
  NO_DEFAULT_PATH
  DOC "Path to ANAKIN library.")

# Start from a known state so that ANAKIN_FOUND is never left undefined (or
# carrying a value set elsewhere) when the files exist but WITH_DSO is off.
set(ANAKIN_FOUND OFF)
if(ANAKIN_INCLUDE_DIR AND ANAKIN_LIBRARY AND WITH_DSO)
  set(ANAKIN_FOUND ON)
endif()

if(ANAKIN_FOUND)
  message(STATUS "Current ANAKIN header is ${ANAKIN_INCLUDE_DIR}/anakin_config.h. ")
  # NOTE(review): these paths are derived from the ANAKIN_ROOT cache variable,
  # not from ANAKIN_INCLUDE_DIR / ANAKIN_LIBRARY; when Anakin was located via
  # the ANAKIN_ROOT environment variable they may point elsewhere - confirm.
  include_directories(${ANAKIN_ROOT}/include)
  include_directories(${ANAKIN_ROOT}/include/saber)
  link_directories(${ANAKIN_ROOT})
  add_definitions(-DPADDLE_WITH_ANAKIN)
endif()

@ -33,5 +33,6 @@ if(TENSORRT_FOUND)
message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. " message(STATUS "Current TensorRT header is ${TENSORRT_INCLUDE_DIR}/NvInfer.h. "
"Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ") "Current TensorRT version is v${TENSORRT_MAJOR_VERSION}. ")
include_directories(${TENSORRT_INCLUDE_DIR}) include_directories(${TENSORRT_INCLUDE_DIR})
link_directories(${TENSORRT_LIBRARY})
add_definitions(-DPADDLE_WITH_TENSORRT) add_definitions(-DPADDLE_WITH_TENSORRT)
endif() endif()

@ -520,6 +520,7 @@ paddle.fluid.unique_name.guard (ArgSpec(args=['new_generator'], varargs=None, ke
paddle.fluid.recordio_writer.convert_reader_to_recordio_file (ArgSpec(args=['filename', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None)), ('document', '65c7523e86f0c50bb729b01667f36310')) paddle.fluid.recordio_writer.convert_reader_to_recordio_file (ArgSpec(args=['filename', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None)), ('document', '65c7523e86f0c50bb729b01667f36310'))
paddle.fluid.recordio_writer.convert_reader_to_recordio_files (ArgSpec(args=['filename', 'batch_per_file', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None)), ('document', 'bc643f0f5f1b9db57ff0d8a57d379bd7')) paddle.fluid.recordio_writer.convert_reader_to_recordio_files (ArgSpec(args=['filename', 'batch_per_file', 'reader_creator', 'feeder', 'compressor', 'max_num_records', 'feed_order'], varargs=None, keywords=None, defaults=(Compressor.Snappy, 1000, None)), ('document', 'bc643f0f5f1b9db57ff0d8a57d379bd7'))
paddle.fluid.Scope Scope() -> paddle.fluid.core._Scope paddle.fluid.Scope Scope() -> paddle.fluid.core._Scope
paddle.fluid.install_check.run_check (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', '66b7c84a17ed32fec2df9628367be2b9'))
paddle.reader.cache (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '1676886070eb607cb608f7ba47be0d3c')) paddle.reader.cache (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '1676886070eb607cb608f7ba47be0d3c'))
paddle.reader.map_readers (ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None), ('document', '77cbadb09df588e21e5cc0819b69c87d')) paddle.reader.map_readers (ArgSpec(args=['func'], varargs='readers', keywords=None, defaults=None), ('document', '77cbadb09df588e21e5cc0819b69c87d'))
paddle.reader.buffered (ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None), ('document', '0d6186f109feceb99f60ec50a0a624cb')) paddle.reader.buffered (ArgSpec(args=['reader', 'size'], varargs=None, keywords=None, defaults=None), ('document', '0d6186f109feceb99f60ec50a0a624cb'))

@ -5,6 +5,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry) cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper) cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper) cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
@ -72,7 +73,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass) cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle) scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle) cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <algorithm> #include <algorithm>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
@ -52,13 +53,28 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
// Note that must assert topology sort is stable // Note that must assert topology sort is stable
auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs); auto& ops = graph->Get<const std::vector<OpDesc*>>(kStaleProgramOpDescs);
for (auto* op_desc : ops) { for (auto* op_desc : ops) {
auto outputs = op_desc->Outputs(); try {
for (auto& o_it : outputs) { bool is_bk_op =
for (auto& v : o_it.second) { // values static_cast<bool>(boost::get<int>(op_desc->GetAttr(
vars[v] = order; OpProtoAndCheckerMaker::OpRoleAttrName())) &
static_cast<int>(OpRole::kBackward));
if (!is_bk_op) continue;
auto backward_vars =
boost::get<std::vector<std::string>>(op_desc->GetNullableAttr(
OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
auto outputs = op_desc->Outputs();
for (auto& o_it : outputs) {
for (auto& v : o_it.second) { // values
vars[v] = order;
VLOG(1) << "in all_reduce_deps_pass:" << v;
}
} }
order++;
} catch (boost::bad_get e) {
} }
order++;
} }
std::vector<OpHandleBase*> dist_ops; std::vector<OpHandleBase*> dist_ops;

@ -0,0 +1,66 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
// Builds the op handle for a fetch_barrier op node.
//
// node:         graph node wrapping the fetch_barrier OpDesc; the operator is
//               created from node->Op().
// local_scopes: per-device scopes; the first one is used as the run scope.
// places:       devices taking part; the first one is where the op runs.
//
// The op itself always runs on places[0], but a device context is registered
// for every place because its outputs are produced on all places.
//
// Throws std::out_of_range if local_scopes or places is empty.
FetchBarrierOpHandle::FetchBarrierOpHandle(
    ir::Node *node, const std::vector<Scope *> &local_scopes,
    const std::vector<platform::Place> &places)
    : OpHandleBase(node),
      op_(framework::OpRegistry::CreateOp(*node->Op())),
      local_scopes_(local_scopes),
      places_(places),
      // at(0) rather than [0]: an empty argument now fails loudly instead of
      // reading past the end of the vector.
      run_scope_(local_scopes.at(0)),
      place_(places.at(0)) {
  for (const auto &place : places) {
    this->SetDeviceContext(place,
                           platform::DeviceContextPool::Instance().Get(place));
  }
}
// Overrides the base-class query: this handle always reports itself as a
// multi-device transfer.
bool FetchBarrierOpHandle::IsMultiDeviceTransfer() { return true; }
// Waits for the inputs on place_, then runs the wrapped operator on place_
// inside the scope stored under kLocalExecScopeName in run_scope_.
void FetchBarrierOpHandle::RunImpl() {
  WaitInputVarGenerated(place_);

  // The scope lookup is deliberately done inside the callable so that it
  // happens at the moment the op is actually executed.
  auto execute_op = [this]() {
    auto *local_exec_scope =
        run_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
    op_->Run(*local_exec_scope, place_);
  };

  if (!is_lock_and_record_event_free_) {
    this->RunAndRecordEvent(execute_op);
    return;
  }
  execute_op();
}
// Decides whether this handle has to wait for in_var.
// Returns true only when in_var exists, has a generating op, and that op's
// device context for place_ differs from this handle's context for place_.
bool FetchBarrierOpHandle::NeedWait(VarHandleBase *in_var) {
  if (!in_var) return false;
  if (!in_var->GeneratedOp()) return false;
  return in_var->GeneratedOp()->DeviceContext(place_) !=
         dev_ctxes_.at(place_);
}
// The handle's name is the type string of the operator it wraps.
std::string FetchBarrierOpHandle::Name() const {
  return op_->Type();
}
} // namespace details
} // namespace framework
} // namespace paddle

@ -0,0 +1,61 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
// Op handle for the fetch_barrier operator.
//
// **NOTE**: fetch_barrier is special: it outputs all received variables on
// all places. If there are multiple places, the handle must therefore be
// initialized with multiple dev_ctxes_ (one per place).
struct FetchBarrierOpHandle : public OpHandleBase {
public:
// node:         graph node whose OpDesc describes the fetch_barrier op.
// local_scopes: scopes of the participating devices.
// places:       participating devices.
FetchBarrierOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
// Overrides OpHandleBase::IsMultiDeviceTransfer.
bool IsMultiDeviceTransfer() override;
// Overrides OpHandleBase::Name.
std::string Name() const override;
protected:
// Executes the wrapped operator.
void RunImpl() override;
// Tells whether the handle must wait for in_var before running.
bool NeedWait(VarHandleBase *in_var) override;
private:
// The operator instance wrapped by this handle.
std::unique_ptr<OperatorBase> op_;
// Copies of the scopes and places passed to the constructor.
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
// Scope and place used for running the operator.
// NOTE(review): presumably the first entries of local_scopes_ / places_ -
// confirm against the constructor definition.
Scope *run_scope_;
platform::Place place_;
// Defaults to false; nothing in this declaration changes it.
// NOTE(review): the name suggests it selects a run path that skips event
// recording - confirm against RunImpl.
bool is_lock_and_record_event_free_{false};
};
} // namespace details
} // namespace framework
} // namespace paddle

@ -17,6 +17,8 @@
#include <deque> #include <deque>
#include <iterator> #include <iterator>
#include <memory> #include <memory>
#include <queue>
#include <sstream>
#include <stack> #include <stack>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
@ -148,12 +150,14 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
view_.Build(graph.get()); view_.Build(graph.get());
InitSSAGraphNodes(); InitSSAGraphNodes();
auto cnt = 0;
for (auto* op : view_.AllOps()) { for (auto* op : view_.AllOps()) {
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name())) if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
continue; continue;
TryInplaceOpInputOutput(op, graph.get()); TryInplaceOpInputOutput(op, graph.get());
} }
graph->ResolveHazard(var_nodes_); // graph->ResolveHazard(var_nodes_);
return graph; return graph;
} }
@ -264,13 +268,10 @@ void InplacePass::WithdrawModify(const NodeSwapQueue& nodes,
void InplacePass::TryInplaceOpInputOutput(ir::Node* op, void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
ir::Graph* graph) const { ir::Graph* graph) const {
VLOG(4) << "Try to inplace op " << op->Name(); VLOG(4) << "Try to inplace op " << op->Name();
// FIXME(liuwei1031): Graph is not aware of the existence of BlockDescs and // PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
// ProgramDescs. // "op_desc is nullptr");
// The operations related to BlockDesc or ProgramDesc should perform on Graph
// or Node directly!
PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
"op_desc is nullptr");
// some pre-requirments need to meet if the op want to inplaced. // some pre-requirments need to meet if the op want to inplaced.
PADDLE_ENFORCE(op->Op() != nullptr, "op_desc is nullptr");
auto* op_desc = op->Op(); auto* op_desc = op->Op();
auto& infer_inplace = auto& infer_inplace =
@ -281,21 +282,58 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
PADDLE_ENFORCE(static_cast<bool>(infer_inplace), PADDLE_ENFORCE(static_cast<bool>(infer_inplace),
"%s's infer_inplace has not been registered", op_desc->Type()); "%s's infer_inplace has not been registered", op_desc->Type());
auto* block = op_desc->Block(); auto in_to_outs = infer_inplace(*op_desc);
auto in_to_outs = infer_inplace(*op_desc, block);
auto& all_ops = view_.AllOps(); auto& all_ops = view_.AllOps();
auto cursor = std::find(all_ops.begin(), all_ops.end(), op); auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
size_t idx = std::distance(all_ops.begin(), cursor); size_t idx = std::distance(all_ops.begin(), cursor);
for (auto& pair : in_to_outs) { for (auto& pair : in_to_outs) {
auto& in_var_name = pair.first; auto& in_para_name = pair.first;
auto& out_var_name = pair.second; auto& out_para_name = pair.second;
auto input_vars = op->Op()->Input(in_para_name);
if (!input_vars.size()) {
VLOG(4) << "Parameter " << in_para_name << " is empty skip "
<< in_para_name << " => " << out_para_name << " pair";
continue;
}
auto output_vars = op->Op()->Output(out_para_name);
if (!output_vars.size()) {
VLOG(4) << "Parameter " << out_para_name << " is empty skip "
<< in_para_name << " => " << out_para_name << " pair";
continue;
}
auto in_var_name = input_vars.at(0);
auto out_var_name = output_vars.at(0);
auto* in_node = view_.GetNodeByName(in_var_name, op->inputs); auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
auto* out_node = view_.GetNodeByName(out_var_name, op->outputs); auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
bool can_replace = true;
if (in_var_name == out_var_name) {
can_replace = false;
VLOG(4) << "SKIP: Input variable " << in_var_name << " & Output variable "
<< out_var_name << " are the same";
} else if (!NodeCanReused(in_node)) {
can_replace = false;
VLOG(4) << "SKIP: Input varialbe " << in_var_name << "cannot be reused";
} else if (!NodeCanReused(out_node)) {
can_replace = false;
VLOG(4) << "SKIP: Output variable " << out_var_name
<< " cannot be reused";
} else if (details::NodeSize(*in_node->Var()) !=
details::NodeSize(*out_node->Var())) {
can_replace = false;
VLOG(4) << "SKIP: Input and Output varialbe size not match";
}
if (!can_replace) continue;
// 2. there is no external pending op on the input node // 2. there is no external pending op on the input node
if (view_.PendingOpsOnVar(in_node).size() > 1) { // if (view_.PendingOpsOnVar(in_node).size() > 1) {
if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) {
VLOG(4) << string::Sprintf( VLOG(4) << string::Sprintf(
"Skiped pair %s => %s. %s input has external dependency." "Skiped pair %s => %s. %s input has external dependency."
"inplace such pair will overwrite the memory.", "inplace such pair will overwrite the memory.",
@ -342,6 +380,97 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
} }
} }
// Rebuilds ops_ as a topological ordering of the op nodes in `graph` that
// carry an OpDesc, and records for each of them a level in op_level_:
// ops with no produced inputs get level 0, an op released by an op of
// level L is recorded with level L + 1.
//
// Enforces (PADDLE_ENFORCE) that every tracked op has been released once the
// traversal finishes.
void GraphView::TopoSort(ir::Graph* graph) {
  ops_.clear();
  // Levels recorded by an earlier call belong to a previous traversal; drop
  // them together with ops_ so that no stale entries survive.
  op_level_.clear();

  // Number of input variables of `op` that are produced by some op.
  auto deps_num = [](ir::Node* op) {
    uint32_t cnt = 0;
    for (auto& var : op->inputs)
      if (var->inputs.size() > 0) ++cnt;
    return cnt;
  };

  std::queue<std::pair<ir::Node*, uint32_t>> ready_ops;
  const uint32_t level = 0;
  const auto& nodes = graph->Nodes();
  std::unordered_map<ir::Node*, uint32_t> deps_map;

  // Seed the queue with every op that has no pending dependency.
  for (auto& node : nodes) {
    if (node->IsOp() && node->Op() != nullptr) {
      deps_map[node] = deps_num(node);
      if (0 == deps_map[node]) {
        ready_ops.push({node, level});
      }
    }
  }

  while (!ready_ops.empty()) {
    auto item = ready_ops.front();
    ready_ops.pop();

    ops_.emplace_back(item.first);
    // record level when pop from queue
    op_level_[item.first] = item.second;

    for (auto node : item.first->outputs) {
      for (auto op : node->outputs) {
        // Only ops registered above are tracked. Looking the op up instead
        // of using operator[] avoids inserting untracked nodes and wrapping
        // their unsigned counter around.
        auto tracked = deps_map.find(op);
        if (tracked == deps_map.end()) continue;
        --tracked->second;
        if (tracked->second == 0) ready_ops.push({op, item.second + 1});
      }
    }
  }

  bool all_ops_checked = true;
  for (auto& node : nodes) {
    if (node->IsOp() && node->Op() != nullptr && deps_map[node] > 0) {
      all_ops_checked = false;
      break;
    }
  }

  PADDLE_ENFORCE(all_ops_checked, "All ops deps should be 0 after analysis");
}
// Returns true if current_op depends on every other op that consumes the
// variable node `var`; returns false as soon as one consumer is found that
// current_op does not depend on.
bool GraphView::CheckDeps(ir::Node* var, ir::Node* current_op) const {
  // All ops that read this variable node.
  auto consumers = var->outputs;
  for (auto& consumer : consumers) {
    if (consumer != current_op) {
      VLOG(4) << " GraphView::CheckDeps : " << consumer->Name() << " & "
              << current_op->Name();
      if (!CheckOpDeps(consumer, current_op)) {
        return false;
      }
      VLOG(4) << "";
    }
  }
  return true;
}
// Checks whether op2 depends on op1's output.
// The search walks backwards from op2 through the first producer of each of
// its input variables, and stops on any path whose op is not on a strictly
// later level than op1. An op is considered to depend on itself.
bool GraphView::CheckOpDeps(ir::Node* op1, ir::Node* op2) const {
  // Logs one op together with its argument names and recorded level.
  auto log_op = [&](ir::Node* op, const char* name) {
    std::ostringstream msg;
    msg << " " << name << " : " << op->Name() << " "
        << "Input args : ";
    for (auto& arg : op->inputs) {
      msg << arg->Name() << " ";
    }
    msg << "Output args : ";
    for (auto& arg : op->outputs) {
      msg << arg->Name() << " ";
    }
    msg << "Level : " << op_level_.at(op);
    VLOG(4) << msg.str();
  };
  log_op(op1, "OP1");
  log_op(op2, "OP2");

  if (op1 == op2) {
    return true;
  }
  // op1 has to sit on a strictly earlier level for op2 to depend on it.
  if (op_level_.at(op1) >= op_level_.at(op2)) {
    return false;
  }

  for (auto& var : op2->inputs) {
    if (var->inputs.size() > 0 && CheckOpDeps(op1, var->inputs[0])) {
      return true;
    }
  }
  return false;
}
ir::Node* GraphView::GetNodeByName(const std::string& name, ir::Node* GraphView::GetNodeByName(const std::string& name,
const std::vector<ir::Node*>& nodes) const { const std::vector<ir::Node*>& nodes) const {
// nodes should be op->inputs/outputs // nodes should be op->inputs/outputs
@ -387,22 +516,7 @@ void GraphView::Build(ir::Graph* g) {
// Because we insert some new created node. Which may have data race between // Because we insert some new created node. Which may have data race between
// nodes. // nodes.
// resolve data harzards depends on the var nodes in right order. // resolve data harzards depends on the var nodes in right order.
ops_ = SortOpLikeDescOrder(*g); TopoSort(g);
// 1. track the nodes which reused previous node in Python memory optimize.
// these node can not be inplaced, otherwise may generate a circle in graph.
std::unordered_set<std::string> all_vars;
for (auto& node : g->Nodes()) {
if (node->IsVar()) continue;
for (auto& out : node->outputs) {
if (out->IsCtrlVar() || out->Var() == nullptr) continue;
if (all_vars.count(out->Name())) {
dup_nodes_.emplace(out->Name());
} else {
all_vars.emplace(out->Name());
}
}
}
// 2. track the nodes which used by parameter server. // 2. track the nodes which used by parameter server.
// these node can not be inplaced, otherwise trainer // these node can not be inplaced, otherwise trainer

@ -14,6 +14,7 @@
#pragma once #pragma once
#include <map> #include <map>
#include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
@ -50,10 +51,15 @@ class GraphView {
// map the parameter and gradient, must be skipped. // map the parameter and gradient, must be skipped.
bool InSkipSet(const std::string& var) const; bool InSkipSet(const std::string& var) const;
bool CheckDeps(ir::Node* var, ir::Node* current_op) const;
bool CheckOpDeps(ir::Node* op1, ir::Node* op2) const;
void TopoSort(ir::Graph* g);
private: private:
std::vector<ir::Node*> ops_; std::vector<ir::Node*> ops_;
std::unordered_set<std::string> dup_nodes_; // mem opt affect nodes std::unordered_set<std::string> dup_nodes_; // mem opt affect nodes
std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_; std::map<ir::Node*, std::unordered_set<ir::Node*>> adj_list_;
std::unordered_map<ir::Node*, uint32_t> op_level_;
}; };
// swap pairs in sequence // swap pairs in sequence

@ -190,7 +190,7 @@ struct NodeComparator {
auto rhs_shape = rhs_desc->GetShape(); auto rhs_shape = rhs_desc->GetShape();
if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) || if ((lhs_shape[0] == -1 && rhs_shape[0] == -1) ||
(lhs_shape[0] != -1 && rhs_shape[0] != -1)) { (lhs_shape[0] != -1 && rhs_shape[0] != -1)) {
return NodeSize(lhs) <= NodeSize(rhs); return NodeSize(lhs) == NodeSize(rhs);
} else { } else {
return false; return false;
} }
@ -449,6 +449,7 @@ void ControlFlowGraph::LiveVariableAnalysis() {
live_in_[op].insert(var); live_in_[op].insert(var);
} }
for (auto& var : defs_[op]) { for (auto& var : defs_[op]) {
if (uses_[op].count(var)) continue;
live_in_[op].erase(var); live_in_[op].erase(var);
} }

@ -142,15 +142,16 @@ TEST(OrderedSet, FindBestFitNode) {
for (auto& node : nodes) { for (auto& node : nodes) {
pool.Insert(node.get()); pool.Insert(node.get());
} }
// FIXME(liuwei1031) this API has changed,
// disable these tests temporarily
// FindNextBestFitNode // FindNextBestFitNode
auto* n = nodes[0].get(); // auto* n = nodes[0].get();
auto* cache = pool.FindBestFitNode(n); // auto* cache = pool.FindBestFitNode(n);
PADDLE_ENFORCE(cache->Name() == "a"); // PADDLE_ENFORCE(cache->Name() == "a");
cache = pool.FindNextBestFitNode(n, cache); // cache = pool.FindNextBestFitNode(n, cache);
PADDLE_ENFORCE(cache->Name() == "c"); // PADDLE_ENFORCE(cache->Name() == "c");
cache = pool.FindNextBestFitNode(n, cache); // cache = pool.FindNextBestFitNode(n, cache);
PADDLE_ENFORCE(cache->Name() == "b"); // PADDLE_ENFORCE(cache->Name() == "b");
} }
} // namespace details } // namespace details

@ -23,6 +23,7 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h"
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h" #include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
@ -851,9 +852,17 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
node->Op()->Type()); node->Op()->Type());
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], // Create fetch_barrier op handle to enable output on all devices.
node->Op()->Type(), places_[op_dev_id])); // **NOTE** fetch_barrier should output variables list same as recv op does.
if (node->Op()->Type() == "fetch_barrier") {
result->Get<GraphOps>(kGraphOps).emplace_back(new FetchBarrierOpHandle(
result->CreateOpNode(node->Op()), local_scopes_, places_));
} else {
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id]));
}
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
CreateOpHandleIOs(result, node, op_dev_id); CreateOpHandleIOs(result, node, op_dev_id);

@ -55,7 +55,7 @@ void OpHandleBase::Run(bool use_cuda) {
if (out_var_handle) { if (out_var_handle) {
int dev_id = int dev_id =
boost::get<platform::CUDAPlace>(out_var_handle->place()).device; boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
out_var_handle->SetGenerateEvent(events_[dev_id]); out_var_handle->SetGenerateEvent(events_.at(dev_id));
} }
} }
} else { } else {
@ -71,7 +71,7 @@ void OpHandleBase::Run(bool use_cuda) {
"The place of input(%s) is not consistent with the " "The place of input(%s) is not consistent with the "
"place of current op(%s).", "place of current op(%s).",
out_var_handle->Name(), Name()); out_var_handle->Name(), Name());
out_var_handle->SetGenerateEvent(events_[dev_id]); out_var_handle->SetGenerateEvent(events_.at(dev_id));
} }
} }
} }

@ -209,9 +209,9 @@ struct OpInfoFiller<T, kShapeInference> {
template <typename T> template <typename T>
struct OpInfoFiller<T, kInplaceOpInference> { struct OpInfoFiller<T, kInplaceOpInference> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->infer_inplace_ = [](const OpDesc& op_desc, BlockDesc* block) { info->infer_inplace_ = [](const OpDesc& op_desc) {
T infer; T infer;
return infer(op_desc, block); return infer(op_desc);
}; };
} }
}; };

@ -17,8 +17,8 @@
#include <numeric> #include <numeric>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/details/memory_optimize_helper.h" #include "paddle/fluid/framework/details/memory_optimize_helper.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/type_defs.h"
@ -32,55 +32,22 @@ namespace framework {
then Out will inplaced use X's memory. The base class will do then Out will inplaced use X's memory. The base class will do
legality validation for both variables. legality validation for both variables.
*/ */
class InplaceOpInference { class InplaceOpInference {
public: public:
virtual ~InplaceOpInference() {} virtual ~InplaceOpInference() {}
virtual std::unordered_map<std::string, std::string> operator()( virtual std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, BlockDesc* block) const = 0; const OpDesc& op_desc) const = 0;
};
class InplaceInToOut : public InplaceOpInference {
public:
std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, BlockDesc* block) const {
std::unordered_map<std::string, std::string> ret;
auto in_out_var_names_pair = this->Apply(op_desc, block);
for (auto& pair : in_out_var_names_pair) {
PADDLE_ENFORCE(!op_desc.Input(pair.first).empty(),
string::Sprintf("op %s do not have input of %s!",
op_desc.Type(), pair.first));
PADDLE_ENFORCE(!op_desc.Output(pair.second).empty(),
string::Sprintf("op %s do not have output of %s!",
op_desc.Type(), pair.second));
auto& in_name = op_desc.Input(pair.first).at(0);
auto& out_name = op_desc.Output(pair.second).at(0);
auto in = block->FindRecursiveOrCreateVar(in_name);
auto out = block->FindRecursiveOrCreateVar(out_name);
if (TryInplaceInputOutput(in, out)) ret.insert({in_name, out_name});
}
return ret;
}
protected:
virtual std::unordered_map<std::string, std::string> Apply(
const OpDesc& op_desc, BlockDesc* block) const = 0;
bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const {
return in.Name() != out.Name() && details::NodeCanReused(in) &&
details::NodeCanReused(out) &&
details::NodeSize(out) <= details::NodeSize(in);
}
}; };
/* /*
Inplace In and Out for operator only have an Input and an Output. Inplace In and Out for operator only have an Input and an Output.
For example, activation op. For example, activation op.
*/ */
class SingleOpInplaceInToOut : public InplaceInToOut { class SingleOpInplaceInToOut : public InplaceOpInference {
protected: public:
std::unordered_map<std::string, std::string> Apply( std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, BlockDesc* block) const override { const OpDesc& op_desc) const override {
PADDLE_ENFORCE(!op_desc.InputNames().empty(), PADDLE_ENFORCE(!op_desc.InputNames().empty(),
"Op inputs must not be empty"); "Op inputs must not be empty");
PADDLE_ENFORCE(!op_desc.OutputNames().empty(), PADDLE_ENFORCE(!op_desc.OutputNames().empty(),
@ -95,10 +62,10 @@ class SingleOpInplaceInToOut : public InplaceInToOut {
Gradient op. Inplace output use it's Input. Gradient op. Inplace output use it's Input.
For example, Input@Grad->Input reuse strategy. For example, Input@Grad->Input reuse strategy.
*/ */
class GradOpInplaceInToOut : public InplaceInToOut { class GradOpInplaceInToOut : public InplaceOpInference {
protected: public:
std::unordered_map<std::string, std::string> Apply( std::unordered_map<std::string, std::string> operator()(
const OpDesc& op_desc, BlockDesc* block) const override { const OpDesc& op_desc) const override {
std::unordered_map<std::string, std::string> ret; std::unordered_map<std::string, std::string> ret;
std::unordered_set<std::string> output_names(op_desc.OutputNames().begin(), std::unordered_set<std::string> output_names(op_desc.OutputNames().begin(),
op_desc.OutputNames().end()); op_desc.OutputNames().end());

File diff suppressed because it is too large Load Diff

@ -46,9 +46,6 @@ cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
pass_library(graph_to_program_pass base) pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base) pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base) pass_library(lock_free_optimize_pass base)
pass_library(cpu_quantize_placement_pass base)
pass_library(cpu_quantize_pass inference)
pass_library(cpu_quantize_squash_pass inference)
pass_library(fc_fuse_pass inference) pass_library(fc_fuse_pass inference)
pass_library(attention_lstm_fuse_pass inference) pass_library(attention_lstm_fuse_pass inference)
pass_library(infer_clean_graph_pass inference) pass_library(infer_clean_graph_pass inference)
@ -71,22 +68,31 @@ pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(identity_scale_op_clean_pass base) pass_library(identity_scale_op_clean_pass base)
pass_library(sync_batch_norm_pass base) pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base) pass_library(runtime_context_cache_pass base)
pass_library(simplify_anakin_detection_pattern_pass inference)
pass_library(anakin_fillconstant_elementwisemul_fuse inference)
# There may be many transpose-flatten structures in a model, and the output of # There may be many transpose-flatten structures in a model, and the output of
# these structures will be used as inputs to the concat Op. This pattern will # these structures will be used as inputs to the concat Op. This pattern will
# be detected by our pass. The index here represents the number of structures in the # be detected by our pass. The index here represents the number of structures in the
# pattern. We use index 3 ~ 6, because these quantities of structures are # pattern. We use index 3 ~ 6, because these quantities of structures are
# common in the models. # common in the models.
foreach (index RANGE 3 6) foreach (index RANGE 2 6)
file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n") file(APPEND ${pass_file} "USE_PASS(transpose_flatten${index}_concat_fuse_pass);\n")
endforeach() endforeach()
foreach (index RANGE 2 6)
file(APPEND ${pass_file} "USE_PASS(simplify_anakin_detection_pattern_pass${index});\n")
endforeach()
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base mkldnn) pass_library(mkldnn_placement_pass base mkldnn)
pass_library(depthwise_conv_mkldnn_pass base mkldnn) pass_library(depthwise_conv_mkldnn_pass base mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
pass_library(cpu_quantize_placement_pass base mkldnn)
pass_library(cpu_quantize_pass inference mkldnn)
pass_library(cpu_quantize_squash_pass inference mkldnn)
endif() endif()
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
@ -105,9 +111,6 @@ cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS g
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto) cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto) cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass) cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
cc_test(test_cpu_quantize_placement_pass SRCS cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
if(NOT WIN32) if(NOT WIN32)
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass) cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
endif() endif()
@ -117,4 +120,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass) cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
endif () endif ()

@ -0,0 +1,85 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
namespace framework {
namespace ir {
// GET_IR_NODE(name) forwards to GET_IR_NODE_FROM_SUBGRAPH using `name` for both
// the variable and the pattern member, and a pattern object that must be
// called `pattern` at the expansion site. The handler below uses the resulting
// identifiers (e.g. fill_constant->Op()), so each expansion introduces a local
// named after the node.
// NOTE(review): GET_IR_NODE_FROM_SUBGRAPH is defined elsewhere; presumably it
// also expects a variable named `subgraph` in scope - confirm before renaming.
// The definition already ends in ';', so call sites get a harmless extra ';'.
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
// GET_NODES fetches the four nodes of the fill_constant -> elementwise_mul
// pattern in one statement.
#define GET_NODES \
GET_IR_NODE(fill_constant); \
GET_IR_NODE(fill_constant_out); \
GET_IR_NODE(elementwise_mul); \
GET_IR_NODE(elementwise_mul_out);
// Rewrites every `elementwise_mul(X, fill_constant())` occurrence in `graph`
// into a single `scale` op whose "scale" attribute is the fill_constant
// "value" attribute, then returns the (mutated) graph.
std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
  FusePassBase::Init(pattern_name, graph.get());

  GraphPatternDetector gpd;
  auto* detector_pattern = gpd.mutable_pattern();

  // `x` is the non-constant operand: whatever feeds input "X" of the multiply.
  auto* x = detector_pattern->NewNode("x")
                ->assert_is_op_input("elementwise_mul", "X")
                ->AsInput();

  // NOTE: this object must be named `pattern`; GET_NODES refers to it by name.
  patterns::AnakinFillConstantElementWiseMulFuse pattern(detector_pattern,
                                                         pattern_name);
  pattern(x);

  // Invoked once per match; `subgraph` is referenced by name inside GET_NODES.
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;
    PADDLE_ENFORCE(subgraph.count(x));
    auto* mul_input = subgraph.at(x);

    // The constant that was being multiplied in becomes the scale factor.
    float scale_factor =
        boost::get<float>(fill_constant->Op()->GetAttr("value"));

    framework::OpDesc scale_desc;
    scale_desc.SetType("scale");
    scale_desc.SetInput("X", {mul_input->Name()});
    scale_desc.SetAttr("scale", scale_factor);
    scale_desc.SetAttr("bias", static_cast<float>(0.0));
    scale_desc.SetAttr("bias_after_scale", true);
    scale_desc.SetOutput("Out", {elementwise_mul_out->Name()});
    scale_desc.Flush();

    // Splice the new scale op between the original input and output vars.
    auto* scale_node = graph->CreateOpNode(&scale_desc);
    IR_NODE_LINK_TO(mul_input, scale_node);            // Input
    IR_NODE_LINK_TO(scale_node, elementwise_mul_out);  // Output

    // Drop the ops/vars that the scale op replaces.
    GraphSafeRemoveNodes(graph.get(),
                         {fill_constant, fill_constant_out, elementwise_mul});
  };

  gpd(graph.get(), handler);
  return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers the pass class under the name
// anakin_fillconstant_elementwisemul_fuse.
// NOTE(review): REGISTER_PASS is defined elsewhere; presumably this is the
// name referenced by USE_PASS(...) and by pass lookup - confirm.
REGISTER_PASS(anakin_fillconstant_elementwisemul_fuse,
paddle::framework::ir::AnakinFillconstantElementwisemulFuse);

@ -0,0 +1,35 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
// Fuse pass built on FusePassBase. The graph rewrite itself lives in
// ApplyImpl(), which is defined out of line in the matching .cc file.
class AnakinFillconstantElementwisemulFuse : public FusePassBase {
 public:
  virtual ~AnakinFillconstantElementwisemulFuse() = default;

 protected:
  // Takes ownership of `graph`, applies the fusion and hands the graph back.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle

@ -1470,6 +1470,171 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
return concat_out; return concat_out;
} }
// Builds the detection-head pattern and returns its output node.
//
// The pattern consists of `times` parallel branches of
//   density_prior_box -> {Boxes, Variances} -> reshape2 -> reshape2 output
// whose reshaped Boxes feed one concat and whose reshaped Variances feed a
// second concat. Both concat outputs plus one extra input feed box_coder; a
// second extra input goes through transpose2; box_coder and transpose2 outputs
// feed multiclass_nms.
//
// conv_in layout (must hold at least times + 2 entries):
//   conv_in[0 .. times-1]  input of the i-th density_prior_box
//   conv_in[times]         third input of box_coder
//   conv_in[times + 1]     input of the transpose2 that precedes multiclass_nms
//
// conv_in is accessed with at(), so a vector that is too short raises
// std::out_of_range instead of reading past the end.
//
// Returns the PDNode matching the multiclass_nms output.
PDNode *patterns::AnakinDetectionPattern::operator()(
    std::vector<PDNode *> conv_in, int times) {
  // `times` is the number of repeated branches. Each branch contributes
  // kNumFields consecutive entries to `nodes`, in this order:
  // {prior_box, box_out, reshape1, reshape1_out, box_var_out, reshape2,
  //  reshape2_out}
  const int kNumFields = 7;
  const int kPriorBoxLocOffset = 1;
  const int kReshape1Offset = 2;
  const int kReshape1OutOffset = 3;
  const int kPriorBoxVarOffset = 4;
  const int kReshape2Offset = 5;
  const int kReshape2OutOffset = 6;

  // Positions of the two trailing, non-branch inputs inside conv_in.
  const int kBoxCoderThirdInputOffset = times;
  const int kMultiClassSecondInputNmsOffset = times + 1;

  // Create the per-branch nodes; they are linked in the second loop below.
  std::vector<PDNode *> nodes;
  for (int i = 0; i < times; i++) {
    nodes.push_back(
        pattern->NewNode(GetNodeName("prior_box" + std::to_string(i)))
            ->assert_is_op("density_prior_box"));
    nodes.push_back(pattern->NewNode(GetNodeName("box_out" + std::to_string(i)))
                        ->assert_is_op_output("density_prior_box", "Boxes")
                        ->assert_is_op_input("reshape2", "X")
                        ->AsIntermediate());
    nodes.push_back(
        pattern->NewNode(GetNodeName("reshape1" + std::to_string(i)))
            ->assert_is_op("reshape2"));

    nodes.push_back(
        pattern->NewNode(GetNodeName("reshape1_out" + std::to_string(i)))
            ->assert_is_op_output("reshape2")
            ->assert_is_op_nth_input("concat", "X", i)
            ->AsIntermediate());

    nodes.push_back(
        pattern->NewNode(GetNodeName("box_var_out" + std::to_string(i)))
            ->assert_is_op_output("density_prior_box", "Variances")
            ->assert_is_op_input("reshape2", "X")
            ->AsIntermediate());
    nodes.push_back(
        pattern->NewNode(GetNodeName("reshape2" + std::to_string(i)))
            ->assert_is_op("reshape2"));

    nodes.push_back(
        pattern->NewNode(GetNodeName("reshape2_out" + std::to_string(i)))
            ->assert_is_op_output("reshape2")
            ->assert_is_op_nth_input("concat", "X", i)
            ->AsIntermediate());
  }

  // Nodes shared by all branches.
  auto concat_op1 = pattern->NewNode(GetNodeName("concat1"))
                        ->assert_is_op("concat")
                        ->assert_op_has_n_inputs("concat", times);
  auto concat_out1 = pattern->NewNode(GetNodeName("concat1_out"))
                         ->assert_is_op_output("concat")
                         ->AsIntermediate();

  auto concat_op2 = pattern->NewNode(GetNodeName("concat2"))
                        ->assert_is_op("concat")
                        ->assert_op_has_n_inputs("concat", times);
  auto concat_out2 = pattern->NewNode(GetNodeName("concat2_out"))
                         ->assert_is_op_output("concat")
                         ->AsIntermediate();

  auto box_coder_op = pattern->NewNode(GetNodeName("box_coder"))
                          ->assert_is_op("box_coder")
                          ->assert_op_has_n_inputs("box_coder", 3);

  auto box_coder_out = pattern->NewNode(GetNodeName("box_coder_out"))
                           ->assert_is_op_output("box_coder")
                           ->AsIntermediate();

  auto transpose_before_nms =
      pattern->NewNode(GetNodeName("transpose_before_nms"))
          ->assert_is_op("transpose2");

  auto transpose_before_nms_out =
      pattern->NewNode(GetNodeName("transpose_before_nms_out"))
          ->assert_is_op_output("transpose2")
          ->assert_is_op_input("multiclass_nms", "Scores")
          ->AsIntermediate();

  auto multiclass_nms_op = pattern->NewNode(GetNodeName("multiclass_nms"))
                               ->assert_is_op("multiclass_nms")
                               ->assert_op_has_n_inputs("multiclass_nms", 2);

  auto multiclass_nms_out = pattern->NewNode(GetNodeName("multiclass_nms_out"))
                                ->assert_is_op_output("multiclass_nms")
                                ->AsOutput();

  // Wire up each branch and collect the reshape outputs for the two concats.
  std::vector<PDNode *> reshape1_outs;
  std::vector<PDNode *> reshape2_outs;

  for (int i = 0; i < times; i++) {
    // Bounds-checked: the caller must provide one input per branch.
    PDNode *branch_input = conv_in.at(i);
    branch_input->AsInput();
    // prior_box
    nodes[i * kNumFields]->LinksFrom({branch_input});
    // prior_box box out
    nodes[i * kNumFields + kPriorBoxLocOffset]->LinksFrom(
        {nodes[i * kNumFields]});
    // reshape
    nodes[i * kNumFields + kReshape1Offset]->LinksFrom(
        {nodes[i * kNumFields + kPriorBoxLocOffset]});
    // reshape_out
    nodes[i * kNumFields + kReshape1OutOffset]->LinksFrom(
        {nodes[i * kNumFields + kReshape1Offset]});

    // prior_box variances out
    nodes[i * kNumFields + kPriorBoxVarOffset]->LinksFrom(
        {nodes[i * kNumFields]});
    // reshape
    nodes[i * kNumFields + kReshape2Offset]->LinksFrom(
        {nodes[i * kNumFields + kPriorBoxVarOffset]});
    // reshape_out
    nodes[i * kNumFields + kReshape2OutOffset]->LinksFrom(
        {nodes[i * kNumFields + kReshape2Offset]});

    reshape1_outs.push_back(nodes[i * kNumFields + kReshape1OutOffset]);
    reshape2_outs.push_back(nodes[i * kNumFields + kReshape2OutOffset]);
  }

  concat_op1->LinksFrom(reshape1_outs);
  concat_op2->LinksFrom(reshape2_outs);
  concat_out1->LinksFrom({concat_op1});
  concat_out2->LinksFrom({concat_op2});

  // Bounds-checked: the two trailing inputs sit right after the branch inputs.
  PDNode *box_coder_third_input = conv_in.at(kBoxCoderThirdInputOffset);
  PDNode *nms_scores_input = conv_in.at(kMultiClassSecondInputNmsOffset);
  box_coder_third_input->AsInput();
  nms_scores_input->AsInput();

  box_coder_op->LinksFrom({concat_out1, concat_out2, box_coder_third_input});
  box_coder_out->LinksFrom({box_coder_op});

  transpose_before_nms->LinksFrom({nms_scores_input});
  transpose_before_nms_out->LinksFrom({transpose_before_nms});

  multiclass_nms_op->LinksFrom({box_coder_out, transpose_before_nms_out})
      .LinksTo({multiclass_nms_out});

  return multiclass_nms_out;
}
// Describes `elementwise_mul(elementwise_op_input, fill_constant())`:
// a fill_constant op whose output is input "Y" of an elementwise_mul op, with
// `elementwise_op_input` as the multiply's other linked input.
// Returns the node matching the elementwise_mul output.
PDNode *patterns::AnakinFillConstantElementWiseMulFuse::operator()(
    PDNode *elementwise_op_input) {
  // Constant producer and the variable it writes.
  auto *const_op =
      pattern->NewNode(fill_constant_repr())->assert_is_op("fill_constant");
  auto *const_var = pattern->NewNode(fill_constant_out_repr())
                        ->assert_is_op_output("fill_constant")
                        ->assert_is_op_input("elementwise_mul", "Y")
                        ->AsIntermediate();

  // The multiply and its result.
  auto *mul_op =
      pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul");
  auto *mul_result = pattern->NewNode(elementwise_mul_out_repr())
                         ->assert_is_op_output("elementwise_mul")
                         ->AsOutput();

  // Edges: fill_constant -> const_var -> mul <- input; mul -> result.
  const_var->LinksFrom({const_op});
  mul_op->LinksFrom({elementwise_op_input, const_var});
  mul_result->LinksFrom({mul_op});
  return mul_result;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -844,6 +844,36 @@ struct TransposeFlattenConcat : public PatternBase {
} }
}; };
// Pattern helper registered with the repr "anakin_detect_pattern". Node names
// are built at run time from a suffix (see GetNodeName) rather than declared
// one by one.
struct AnakinDetectionPattern : public PatternBase {
  AnakinDetectionPattern(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "anakin_detect_pattern") {}

  // Builds the pattern; defined out of line.
  PDNode* operator()(std::vector<PDNode*> conv_inputs, int times);

  // Full node name for `op_type`, derived from this pattern's name scope,
  // repr and id via PDNodeName.
  std::string GetNodeName(const std::string& op_type) {
    return PDNodeName(name_scope_, repr_, id_, op_type);
  }

  // Looks up the PDNode previously created under GetNodeName(op_type).
  PDNode* GetPDNode(const std::string& op_type) {
    const std::string node_name = GetNodeName(op_type);
    return pattern->RetrieveNode(node_name);
  }
};
// Pattern helper registered with the repr
// "anakin_fillconstant_elementwisemul_fuse". It declares four named nodes:
// the fill_constant op and its output, and the elementwise_mul op and its
// output.
struct AnakinFillConstantElementWiseMulFuse : public PatternBase {
AnakinFillConstantElementWiseMulFuse(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"anakin_fillconstant_elementwisemul_fuse") {}
// Builds the pattern around `elementwise_op_input`; defined out of line.
PDNode* operator()(PDNode* elementwise_op_input);
// declare operator node's name
// NOTE(review): PATTERN_DECL_NODE is defined elsewhere; presumably it
// generates the <name>_repr() / <name>_n() accessors used by the pass - confirm.
PATTERN_DECL_NODE(fill_constant);
PATTERN_DECL_NODE(fill_constant_out);
PATTERN_DECL_NODE(elementwise_mul);
PATTERN_DECL_NODE(elementwise_mul_out);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/cpu_quantize_pass.h" #include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save