|
|
|
@ -69,55 +69,59 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto& var : op->outputs) {
|
|
|
|
|
if (!NodeCanReused(var) || cfg_->Use(op).count(var->Name()) == 0 ||
|
|
|
|
|
skip_set_.count(var->Name()))
|
|
|
|
|
if (var->IsVar() && !var->IsCtrlVar() && skip_set_.count(var->Name())) {
|
|
|
|
|
VLOG(3) << "Skip set contains variable of " << var->Name()
|
|
|
|
|
<< "disable reuse on it. skipped";
|
|
|
|
|
continue;
|
|
|
|
|
ir::Node* cache = pool_.FindBestFitNode(var);
|
|
|
|
|
|
|
|
|
|
if (var->Name() == FLAGS_memory_optimize_debug) {
|
|
|
|
|
VLOG(3) << "start match var " << DebugString(var) << " of op "
|
|
|
|
|
<< op->Name();
|
|
|
|
|
VLOG(3) << pool_.ToString();
|
|
|
|
|
VLOG(3) << "matched in pool : "
|
|
|
|
|
<< ((cache == nullptr) ? "False" : "True");
|
|
|
|
|
}
|
|
|
|
|
if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) {
|
|
|
|
|
ir::Node* cache = pool_.FindBestFitNode(var);
|
|
|
|
|
while (cache != nullptr && var->Name() == cache->Name()) {
|
|
|
|
|
VLOG(3) << "The same cache variable is cascade reused. "
|
|
|
|
|
<< cache->Name() << " is re-filled to the pool after "
|
|
|
|
|
<< "the reused op is finished. Current op can not "
|
|
|
|
|
<< "replace it again. Skip this candidate.";
|
|
|
|
|
cache = pool_.FindNextBestFitNode(var, cache);
|
|
|
|
|
}
|
|
|
|
|
if (var->Name() == FLAGS_memory_optimize_debug) {
|
|
|
|
|
VLOG(3) << "start match var " << DebugString(var) << " of op "
|
|
|
|
|
<< op->Name();
|
|
|
|
|
VLOG(3) << pool_.ToString();
|
|
|
|
|
VLOG(3) << "matched in pool : "
|
|
|
|
|
<< ((cache == nullptr) ? "False" : "True");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (cache == nullptr) continue;
|
|
|
|
|
if (var->Name() == cache->Name()) {
|
|
|
|
|
VLOG(3) << "The same cache variable is cascade reused." << var->Name()
|
|
|
|
|
<< " is re-filled to the pool after"
|
|
|
|
|
<< "the reused op is finished. Current op can not "
|
|
|
|
|
<< "replace it again. Skip this candidate.";
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
|
|
int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
|
|
|
|
|
VLOG(3) << string::Sprintf(
|
|
|
|
|
"!!! %s, %s => %s, cache idx %d, pool size %d",
|
|
|
|
|
std::to_string(reuse_id++), DebugString(var), DebugString(cache),
|
|
|
|
|
node_idx_in_pool, static_cast<int>(pool_.size()));
|
|
|
|
|
|
|
|
|
|
// update CFG Graph on the fly.
|
|
|
|
|
// reused var maybe re-fill into the pool
|
|
|
|
|
cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx);
|
|
|
|
|
// NOTE(dzhwinter): we need to both update the ProgramDesc
|
|
|
|
|
// and IR Graph. because op_desc/var_desc is used in CreateOp,
|
|
|
|
|
// CreateVar when running happens. But IR Graph
|
|
|
|
|
// define the dependence relationship between nodes.
|
|
|
|
|
RenameVarInGraphDesc(var->Name(), cache->Name(), idx);
|
|
|
|
|
RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get());
|
|
|
|
|
|
|
|
|
|
pool_.Erase(cache);
|
|
|
|
|
}
|
|
|
|
|
if (cache != nullptr) {
|
|
|
|
|
int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
|
|
|
|
|
VLOG(3) << string::Sprintf(
|
|
|
|
|
"!!! %s, %s => %s, cache idx %d, pool size %d",
|
|
|
|
|
std::to_string(reuse_id++), DebugString(var), DebugString(cache),
|
|
|
|
|
node_idx_in_pool, static_cast<int>(pool_.size()));
|
|
|
|
|
// NOTE(dzhwinter): update the ProgramDesc/IR Graph
|
|
|
|
|
// and the CFG Graph on the fly.
|
|
|
|
|
//
|
|
|
|
|
// IR Graph define the dependence relationship between nodes.
|
|
|
|
|
//
|
|
|
|
|
// ProgramDesc defines the input/output vars. Its used in
|
|
|
|
|
// CreateOp, CreateVar when running happens.
|
|
|
|
|
//
|
|
|
|
|
// CFG Graph store the liveness information, when reuse happens
|
|
|
|
|
// we also need to update the variable liveness.
|
|
|
|
|
const std::string var_name = var->Name();
|
|
|
|
|
const std::string cache_name = cache->Name();
|
|
|
|
|
|
|
|
|
|
// fill the pool
|
|
|
|
|
std::unordered_set<std::string> unlived_vars;
|
|
|
|
|
for (auto var : cfg_->LiveIn(op)) {
|
|
|
|
|
if (cfg_->LiveOut(op).count(var) == 0) {
|
|
|
|
|
unlived_vars.emplace(var);
|
|
|
|
|
cfg_->RenameVarInCFGGraph(var_name, cache_name, idx);
|
|
|
|
|
RenameVarInGraphDesc(var_name, cache_name, idx);
|
|
|
|
|
RenameVarInGraphNode(var_name, cache_name, idx, graph.get());
|
|
|
|
|
pool_.Erase(cache_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto var : unlived_vars) {
|
|
|
|
|
}
|
|
|
|
|
// fill the pool
|
|
|
|
|
for (auto var : cfg_->LiveIn(op)) {
|
|
|
|
|
if (cfg_->LiveOut(op).count(var) == 0) {
|
|
|
|
|
ir::Node* var_node = cfg_->GetNodeByName(var, op);
|
|
|
|
|
if (var_node == nullptr || var_node->IsCtrlVar()) continue;
|
|
|
|
|
if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
|
|
|
|
|
pool_.Insert(var_node);
|
|
|
|
|
}
|
|
|
|
@ -273,8 +277,7 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
|
|
|
|
|
// redirect the input to the latest version of cache_var
|
|
|
|
|
for (auto* node : op->inputs) {
|
|
|
|
|
if (node->Name() == var) {
|
|
|
|
|
ir::Node* cache_node = graph->CreateVarNode(var_desc.get());
|
|
|
|
|
var_nodes_[cache_var].emplace_back(cache_node);
|
|
|
|
|
ir::Node* cache_node = var_nodes_[cache_var].back();
|
|
|
|
|
|
|
|
|
|
// swap node to cache_node
|
|
|
|
|
cache_node->outputs.insert(cache_node->outputs.end(),
|
|
|
|
@ -283,11 +286,15 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
|
|
|
|
|
auto* prev_op = node->inputs[0];
|
|
|
|
|
std::replace(prev_op->outputs.begin(), prev_op->outputs.end(), node,
|
|
|
|
|
cache_node);
|
|
|
|
|
cache_node->inputs.emplace_back(prev_op);
|
|
|
|
|
for (auto* next_op : node->outputs) {
|
|
|
|
|
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
|
|
|
|
|
cache_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// erase unused node
|
|
|
|
|
auto& nodes = var_nodes_.at(var);
|
|
|
|
|
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
|
|
|
|
|
graph->RemoveNode(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -307,15 +314,14 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
|
|
|
|
|
std::replace(next_op->inputs.begin(), next_op->inputs.end(), node,
|
|
|
|
|
cache_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// erase unused node
|
|
|
|
|
auto& nodes = var_nodes_.at(var);
|
|
|
|
|
nodes.erase(std::remove(nodes.begin(), nodes.end(), node), nodes.end());
|
|
|
|
|
graph->RemoveNode(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// release node of unused var in graph
|
|
|
|
|
for (auto* node : var_nodes_[var]) {
|
|
|
|
|
graph->RemoveNode(node);
|
|
|
|
|
}
|
|
|
|
|
var_nodes_.at(var).clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace details
|
|
|
|
|