|
|
|
|
@ -17,6 +17,8 @@
|
|
|
|
|
#include <deque>
|
|
|
|
|
#include <iterator>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <queue>
|
|
|
|
|
#include <sstream>
|
|
|
|
|
#include <stack>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
@ -148,12 +150,14 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
|
|
|
|
|
view_.Build(graph.get());
|
|
|
|
|
InitSSAGraphNodes();
|
|
|
|
|
|
|
|
|
|
auto cnt = 0;
|
|
|
|
|
for (auto* op : view_.AllOps()) {
|
|
|
|
|
VLOG(4) << "Handle op " << cnt++ << ": " << op->Name();
|
|
|
|
|
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
|
|
|
|
|
continue;
|
|
|
|
|
TryInplaceOpInputOutput(op, graph.get());
|
|
|
|
|
}
|
|
|
|
|
graph->ResolveHazard(var_nodes_);
|
|
|
|
|
// graph->ResolveHazard(var_nodes_);
|
|
|
|
|
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|
@ -264,13 +268,10 @@ void InplacePass::WithdrawModify(const NodeSwapQueue& nodes,
|
|
|
|
|
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
|
|
|
|
|
ir::Graph* graph) const {
|
|
|
|
|
VLOG(4) << "Try to inplace op " << op->Name();
|
|
|
|
|
// FIXME(liuwei1031): Graph is not aware of the existence of BlockDescs and
|
|
|
|
|
// ProgramDescs.
|
|
|
|
|
// The operations related to BlockDesc or ProgramDesc should perform on Graph
|
|
|
|
|
// or Node directly!
|
|
|
|
|
PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
|
|
|
|
|
"op_desc is nullptr");
|
|
|
|
|
// PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
|
|
|
|
|
// "op_desc is nullptr");
|
|
|
|
|
// Some prerequisites must be met before the op can be inplaced.
|
|
|
|
|
PADDLE_ENFORCE(op->Op() != nullptr, "op_desc is nullptr");
|
|
|
|
|
|
|
|
|
|
auto* op_desc = op->Op();
|
|
|
|
|
auto& infer_inplace =
|
|
|
|
|
@ -281,21 +282,58 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
|
|
|
|
|
PADDLE_ENFORCE(static_cast<bool>(infer_inplace),
|
|
|
|
|
"%s's infer_inplace has not been registered", op_desc->Type());
|
|
|
|
|
|
|
|
|
|
auto* block = op_desc->Block();
|
|
|
|
|
auto in_to_outs = infer_inplace(*op_desc, block);
|
|
|
|
|
auto in_to_outs = infer_inplace(*op_desc);
|
|
|
|
|
|
|
|
|
|
auto& all_ops = view_.AllOps();
|
|
|
|
|
auto cursor = std::find(all_ops.begin(), all_ops.end(), op);
|
|
|
|
|
size_t idx = std::distance(all_ops.begin(), cursor);
|
|
|
|
|
|
|
|
|
|
for (auto& pair : in_to_outs) {
|
|
|
|
|
auto& in_var_name = pair.first;
|
|
|
|
|
auto& out_var_name = pair.second;
|
|
|
|
|
auto& in_para_name = pair.first;
|
|
|
|
|
auto& out_para_name = pair.second;
|
|
|
|
|
|
|
|
|
|
auto input_vars = op->Op()->Input(in_para_name);
|
|
|
|
|
if (!input_vars.size()) {
|
|
|
|
|
VLOG(4) << "Parameter " << in_para_name << " is empty skip "
|
|
|
|
|
<< in_para_name << " => " << out_para_name << " pair";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto output_vars = op->Op()->Output(out_para_name);
|
|
|
|
|
if (!output_vars.size()) {
|
|
|
|
|
VLOG(4) << "Parameter " << out_para_name << " is empty skip "
|
|
|
|
|
<< in_para_name << " => " << out_para_name << " pair";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto in_var_name = input_vars.at(0);
|
|
|
|
|
auto out_var_name = output_vars.at(0);
|
|
|
|
|
auto* in_node = view_.GetNodeByName(in_var_name, op->inputs);
|
|
|
|
|
auto* out_node = view_.GetNodeByName(out_var_name, op->outputs);
|
|
|
|
|
|
|
|
|
|
VLOG(4) << "Try to inplace " << in_var_name << " with " << out_var_name;
|
|
|
|
|
|
|
|
|
|
bool can_replace = true;
|
|
|
|
|
if (in_var_name == out_var_name) {
|
|
|
|
|
can_replace = false;
|
|
|
|
|
VLOG(4) << "SKIP: Input variable " << in_var_name << " & Output variable "
|
|
|
|
|
<< out_var_name << " are the same";
|
|
|
|
|
} else if (!NodeCanReused(in_node)) {
|
|
|
|
|
can_replace = false;
|
|
|
|
|
VLOG(4) << "SKIP: Input varialbe " << in_var_name << "cannot be reused";
|
|
|
|
|
} else if (!NodeCanReused(out_node)) {
|
|
|
|
|
can_replace = false;
|
|
|
|
|
VLOG(4) << "SKIP: Output variable " << out_var_name
|
|
|
|
|
<< " cannot be reused";
|
|
|
|
|
} else if (details::NodeSize(*in_node->Var()) !=
|
|
|
|
|
details::NodeSize(*out_node->Var())) {
|
|
|
|
|
can_replace = false;
|
|
|
|
|
VLOG(4) << "SKIP: Input and Output varialbe size not match";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!can_replace) continue;
|
|
|
|
|
|
|
|
|
|
// 2. there is no external pending op on the input node
|
|
|
|
|
if (view_.PendingOpsOnVar(in_node).size() > 1) {
|
|
|
|
|
// if (view_.PendingOpsOnVar(in_node).size() > 1) {
|
|
|
|
|
if (in_node->outputs.size() > 1 && !view_.CheckDeps(in_node, op)) {
|
|
|
|
|
VLOG(4) << string::Sprintf(
|
|
|
|
|
"Skiped pair %s => %s. %s input has external dependency."
|
|
|
|
|
"inplace such pair will overwrite the memory.",
|
|
|
|
|
@ -342,6 +380,97 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Rebuilds ops_ as a topological ordering of the graph's operator nodes and
// records in op_level_ the level at which each op became ready.
//
// Only op nodes that carry an OpDesc (node->Op() != nullptr) are tracked. An
// op's pending-dependency count is the number of its input vars that have at
// least one producer node. Ops with no pending dependency start at level 0;
// an op released by an op of level L is queued with level L + 1.
//
// PADDLE_ENFORCE fails if any tracked op still has pending dependencies once
// the ready queue has drained.
void GraphView::TopoSort(ir::Graph* graph) {
  // Reset both results together so that levels recorded by an earlier call
  // cannot survive alongside a freshly rebuilt ops_ list.
  ops_.clear();
  op_level_.clear();

  // Counts the input vars of `op` that are produced by some node.
  auto deps_num = [](ir::Node* op) {
    uint32_t cnt = 0;
    for (auto& var : op->inputs) {
      if (!var->inputs.empty()) ++cnt;
    }
    return cnt;
  };

  std::queue<std::pair<ir::Node*, uint32_t>> ready_ops;
  std::unordered_map<ir::Node*, uint32_t> deps_map;

  const uint32_t level = 0;
  auto nodes = graph->Nodes();
  for (auto& node : nodes) {
    if (!node->IsOp() || node->Op() == nullptr) continue;
    const uint32_t pending = deps_num(node);
    deps_map[node] = pending;
    if (pending == 0) {
      ready_ops.push({node, level});
    }
  }

  while (!ready_ops.empty()) {
    auto item = ready_ops.front();
    ready_ops.pop();

    ops_.emplace_back(item.first);
    // record level when pop from queue
    op_level_[item.first] = item.second;

    for (auto* out_var : item.first->outputs) {
      for (auto* consumer : out_var->outputs) {
        // Consumers that were never registered (e.g. ops without an OpDesc)
        // are skipped: operator[] would default-insert 0 for them and the
        // decrement would wrap the unsigned counter.
        auto tracked = deps_map.find(consumer);
        if (tracked == deps_map.end()) continue;
        if (--(tracked->second) == 0) {
          ready_ops.push({consumer, item.second + 1});
        }
      }
    }
  }

  // deps_map now holds exactly the tracked ops, so it can be scanned directly.
  bool all_ops_checked = true;
  for (auto& entry : deps_map) {
    if (entry.second > 0) {
      all_ops_checked = false;
      break;
    }
  }

  PADDLE_ENFORCE(all_ops_checked, "All ops deps should be 0 after analysis");
}
|
|
|
|
|
|
|
|
|
|
// return true if current op node depeneds on all other op that use the same
|
|
|
|
|
// variable node
|
|
|
|
|
bool GraphView::CheckDeps(ir::Node* var, ir::Node* current_op) const {
|
|
|
|
|
// get op list that rely on the same variable
|
|
|
|
|
auto op_list = var->outputs;
|
|
|
|
|
for (auto& op : op_list) {
|
|
|
|
|
if (op == current_op) continue;
|
|
|
|
|
|
|
|
|
|
VLOG(4) << " GraphView::CheckDeps : " << op->Name() << " & "
|
|
|
|
|
<< current_op->Name();
|
|
|
|
|
if (!CheckOpDeps(op, current_op)) return false;
|
|
|
|
|
VLOG(4) << "";
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// check if op2 depends on op1's output
|
|
|
|
|
bool GraphView::CheckOpDeps(ir::Node* op1, ir::Node* op2) const {
|
|
|
|
|
auto print_op = [&](ir::Node* op, const char* name) {
|
|
|
|
|
std::ostringstream os;
|
|
|
|
|
os << " " << name << " : " << op->Name() << " ";
|
|
|
|
|
os << "Input args : ";
|
|
|
|
|
for (auto& arg : op->inputs) os << arg->Name() << " ";
|
|
|
|
|
os << "Output args : ";
|
|
|
|
|
for (auto& arg : op->outputs) os << arg->Name() << " ";
|
|
|
|
|
os << "Level : " << op_level_.at(op);
|
|
|
|
|
VLOG(4) << os.str();
|
|
|
|
|
};
|
|
|
|
|
print_op(op1, "OP1");
|
|
|
|
|
print_op(op2, "OP2");
|
|
|
|
|
|
|
|
|
|
if (op1 == op2) return true;
|
|
|
|
|
if (op_level_.at(op1) >= op_level_.at(op2)) return false;
|
|
|
|
|
|
|
|
|
|
for (auto& var : op2->inputs)
|
|
|
|
|
if (var->inputs.size() > 0 && CheckOpDeps(op1, var->inputs[0])) return true;
|
|
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ir::Node* GraphView::GetNodeByName(const std::string& name,
|
|
|
|
|
const std::vector<ir::Node*>& nodes) const {
|
|
|
|
|
// nodes should be op->inputs/outputs
|
|
|
|
|
@ -387,22 +516,7 @@ void GraphView::Build(ir::Graph* g) {
|
|
|
|
|
// Because we insert some new created node. Which may have data race between
|
|
|
|
|
// nodes.
|
|
|
|
|
// Resolving data hazards depends on the var nodes being in the right order.
|
|
|
|
|
ops_ = SortOpLikeDescOrder(*g);
|
|
|
|
|
|
|
|
|
|
// 1. track the nodes which reused previous node in Python memory optimize.
|
|
|
|
|
// these node can not be inplaced, otherwise may generate a circle in graph.
|
|
|
|
|
std::unordered_set<std::string> all_vars;
|
|
|
|
|
for (auto& node : g->Nodes()) {
|
|
|
|
|
if (node->IsVar()) continue;
|
|
|
|
|
for (auto& out : node->outputs) {
|
|
|
|
|
if (out->IsCtrlVar() || out->Var() == nullptr) continue;
|
|
|
|
|
if (all_vars.count(out->Name())) {
|
|
|
|
|
dup_nodes_.emplace(out->Name());
|
|
|
|
|
} else {
|
|
|
|
|
all_vars.emplace(out->Name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
TopoSort(g);
|
|
|
|
|
|
|
|
|
|
// 2. track the nodes which used by parameter server.
|
|
|
|
|
// these node can not be inplaced, otherwise trainer
|
|
|
|
|
|