|
|
|
@ -199,15 +199,17 @@ void InplacePass::InplaceModifyDesc(const std::string& var,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const SSANodeVector InplacePass::TryInplaceModifyVar(
|
|
|
|
|
const std::string& var, const std::string& cache_var, const size_t& idx,
|
|
|
|
|
const SSANodePair InplacePass::TryInplaceModifyVar(const std::string& var,
|
|
|
|
|
const std::string& cache_var,
|
|
|
|
|
const size_t& idx,
|
|
|
|
|
ir::Graph* graph) const {
|
|
|
|
|
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
|
|
|
|
|
var_nodes_[var].at(0)->Var() != nullptr);
|
|
|
|
|
std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
|
|
|
|
|
var_desc->SetName(cache_var);
|
|
|
|
|
|
|
|
|
|
SSANodeVector swap_nodes;
|
|
|
|
|
SSANodePair swap_nodes;
|
|
|
|
|
|
|
|
|
|
for (size_t i = idx; i < view_.AllOps().size(); ++i) {
|
|
|
|
|
auto* op = view_.AllOps()[i];
|
|
|
|
|
|
|
|
|
@ -215,6 +217,7 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
|
|
|
|
|
for (auto* node : op->inputs) {
|
|
|
|
|
if (node->Name() == var) {
|
|
|
|
|
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
|
|
|
|
|
|
|
|
|
|
// swap node to cache_node
|
|
|
|
|
cache_node->outputs.insert(cache_node->outputs.end(),
|
|
|
|
|
node->outputs.begin(), node->outputs.end());
|
|
|
|
@ -228,13 +231,15 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
|
|
|
|
|
cache_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
swap_nodes[node].emplace_back(cache_node);
|
|
|
|
|
swap_nodes.emplace_back(std::make_pair(node, cache_node));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if we need to rename the output,
|
|
|
|
|
// always create a newer version of cache_var
|
|
|
|
|
for (auto* node : op->outputs) {
|
|
|
|
|
if (node->Name() == var) {
|
|
|
|
|
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
|
|
|
|
|
var_nodes_[cache_var].emplace_back(cache_node);
|
|
|
|
|
// swap node to cache node
|
|
|
|
|
cache_node->outputs.insert(cache_node->outputs.end(),
|
|
|
|
|
node->outputs.begin(), node->outputs.end());
|
|
|
|
@ -244,35 +249,35 @@ const SSANodeVector InplacePass::TryInplaceModifyVar(
|
|
|
|
|
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
|
|
|
|
|
cache_node);
|
|
|
|
|
}
|
|
|
|
|
swap_nodes[node].emplace_back(cache_node);
|
|
|
|
|
|
|
|
|
|
swap_nodes.emplace_back(std::make_pair(node, cache_node));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return swap_nodes;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InplacePass::CommitModify(const SSANodeVector& swap_nodes,
|
|
|
|
|
void InplacePass::CommitModify(const SSANodePair& swap_nodes,
|
|
|
|
|
ir::Graph* graph) const {
|
|
|
|
|
for (auto& pair : swap_nodes) {
|
|
|
|
|
auto* node = pair.first;
|
|
|
|
|
const std::string var = node->Name();
|
|
|
|
|
for (auto* cache_node : pair.second) {
|
|
|
|
|
const std::string cache_var = cache_node->Name();
|
|
|
|
|
auto *node = pair.first, *cache_node = pair.second;
|
|
|
|
|
const std::string var = node->Name(), cache_var = cache_node->Name();
|
|
|
|
|
var_nodes_[cache_var].emplace_back(cache_node);
|
|
|
|
|
}
|
|
|
|
|
graph->RemoveNode(node);
|
|
|
|
|
auto& nodes = var_nodes_.at(var);
|
|
|
|
|
// release unused var in graph. Because python side memory optimize
|
|
|
|
|
// may reuse the var under the same name, so we only clear the var node
|
|
|
|
|
// after current inplaced index.
|
|
|
|
|
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
|
|
|
|
|
graph->RemoveNode(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InplacePass::WithDrawModify(const SSANodeVector& nodes,
|
|
|
|
|
void InplacePass::WithdrawModify(const SSANodePair& nodes,
|
|
|
|
|
ir::Graph* graph) const {
|
|
|
|
|
for (auto& pair : nodes) {
|
|
|
|
|
auto* node = pair.first;
|
|
|
|
|
const std::string var = node->Name();
|
|
|
|
|
for (auto* cache_node : pair.second) {
|
|
|
|
|
const std::string cache_var = cache_node->Name();
|
|
|
|
|
auto *node = pair.first, *cache_node = pair.second;
|
|
|
|
|
const std::string var = node->Name(), cache_var = cache_node->Name();
|
|
|
|
|
auto* prev_op = node->inputs[0];
|
|
|
|
|
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), cache_node,
|
|
|
|
|
node);
|
|
|
|
@ -283,71 +288,6 @@ void InplacePass::WithDrawModify(const SSANodeVector& nodes,
|
|
|
|
|
graph->RemoveNode(cache_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InplacePass::InplaceModifyVar(const std::string& var,
|
|
|
|
|
const std::string& cache_var,
|
|
|
|
|
const size_t& idx, ir::Graph* graph) const {
|
|
|
|
|
PADDLE_ENFORCE(var_nodes_[var].size() >= 1 &&
|
|
|
|
|
var_nodes_[var].at(0)->Var() != nullptr);
|
|
|
|
|
std::unique_ptr<VarDesc> var_desc(new VarDesc(*var_nodes_[var].at(0)->Var()));
|
|
|
|
|
var_desc->SetName(cache_var);
|
|
|
|
|
|
|
|
|
|
for (size_t i = idx; i < view_.AllOps().size(); ++i) {
|
|
|
|
|
auto* op = view_.AllOps()[i];
|
|
|
|
|
|
|
|
|
|
// redirect the input to the latest version of cache_var
|
|
|
|
|
for (auto* node : op->inputs) {
|
|
|
|
|
if (node->Name() == var) {
|
|
|
|
|
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
|
|
|
|
|
var_nodes_[cache_var].emplace_back(cache_node);
|
|
|
|
|
|
|
|
|
|
// swap node to cache_node
|
|
|
|
|
cache_node->outputs.insert(cache_node->outputs.end(),
|
|
|
|
|
node->outputs.begin(), node->outputs.end());
|
|
|
|
|
PADDLE_ENFORCE(node->inputs.size() == 1 && node->inputs[0]->IsOp());
|
|
|
|
|
auto* prev_op = node->inputs[0];
|
|
|
|
|
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
|
|
|
|
|
cache_node);
|
|
|
|
|
cache_node->inputs.emplace_back(prev_op);
|
|
|
|
|
for (auto* next_op : node->outputs) {
|
|
|
|
|
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
|
|
|
|
|
cache_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// release unused var in graph. Because python side memory optimize
|
|
|
|
|
// may reused the var in same name, so we only clear the var node
|
|
|
|
|
// after current inplaced index.
|
|
|
|
|
graph->RemoveNode(node);
|
|
|
|
|
auto& nodes = var_nodes_.at(var);
|
|
|
|
|
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if we need to rename the output,
|
|
|
|
|
// always create a newer version of cache_var
|
|
|
|
|
for (auto* node : op->outputs) {
|
|
|
|
|
if (node->Name() == var) {
|
|
|
|
|
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
|
|
|
|
|
var_nodes_[cache_var].emplace_back(cache_node);
|
|
|
|
|
// swap node to cache node
|
|
|
|
|
cache_node->outputs.insert(cache_node->outputs.end(),
|
|
|
|
|
node->outputs.begin(), node->outputs.end());
|
|
|
|
|
cache_node->inputs.emplace_back(op);
|
|
|
|
|
std::replace(op->outputs.begin(), op->outputs.end(), node, cache_node);
|
|
|
|
|
for (auto* next_op : node->outputs) {
|
|
|
|
|
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
|
|
|
|
|
cache_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// release unsed var in graph
|
|
|
|
|
graph->RemoveNode(node);
|
|
|
|
|
auto& nodes = var_nodes_.at(var);
|
|
|
|
|
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
|
|
|
|
|
ir::Graph* graph) const {
|
|
|
|
@ -413,22 +353,23 @@ void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE(dzhwinter):
|
|
|
|
|
// two stage commit of inplaced process. if after inplace happens generate a
|
|
|
|
|
// circle,
|
|
|
|
|
// then withdraw the changes. Otherwise, safely add the node.
|
|
|
|
|
auto swap_nodes =
|
|
|
|
|
TryInplaceModifyVar(out_var_name, in_var_name, idx, graph);
|
|
|
|
|
|
|
|
|
|
// NOTE(dzhwinter):
|
|
|
|
|
// two-stage commit of the inplaced op. If adding such a node generates a circle,
|
|
|
|
|
// then withdraw the changes. Otherwise, safely add the node.
|
|
|
|
|
if (!ir::HasCircle(*graph)) {
|
|
|
|
|
VLOG(3) << string::Sprintf("!!! %s, %s => %s inplaced", op->Name(),
|
|
|
|
|
out_var_name, in_var_name);
|
|
|
|
|
CommitModify(swap_nodes, graph);
|
|
|
|
|
InplaceModifyDesc(out_var_name, in_var_name, idx);
|
|
|
|
|
CommitModify(swap_nodes, graph);
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << string::Sprintf(
|
|
|
|
|
"Skiped pair %s => %s, inplace will generate a circle. withdraw %s",
|
|
|
|
|
out_var_name, in_var_name, op->Name());
|
|
|
|
|
WithDrawModify(swap_nodes, graph);
|
|
|
|
|
WithdrawModify(swap_nodes, graph);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|