|
|
|
@ -69,55 +69,58 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto& var : op->outputs) {
|
|
|
|
|
if (!NodeCanReused(var) || cfg_->Use(op).count(var->Name()) == 0 ||
|
|
|
|
|
skip_set_.count(var->Name()))
|
|
|
|
|
if (skip_set_.count(var->Name())) {
|
|
|
|
|
VLOG(3) << "Skip set contains variable of " << var->Name()
|
|
|
|
|
<< "disable reuse on it. skipped";
|
|
|
|
|
continue;
|
|
|
|
|
ir::Node* cache = pool_.FindBestFitNode(var);
|
|
|
|
|
|
|
|
|
|
if (var->Name() == FLAGS_memory_optimize_debug) {
|
|
|
|
|
VLOG(3) << "start match var " << DebugString(var) << " of op "
|
|
|
|
|
<< op->Name();
|
|
|
|
|
VLOG(3) << pool_.ToString();
|
|
|
|
|
VLOG(3) << "matched in pool : "
|
|
|
|
|
<< ((cache == nullptr) ? "False" : "True");
|
|
|
|
|
}
|
|
|
|
|
if (NodeCanReused(var) && cfg_->Use(op).count(var->Name()) == 0) {
|
|
|
|
|
ir::Node* cache = pool_.FindBestFitNode(var);
|
|
|
|
|
if (var->Name() == FLAGS_memory_optimize_debug) {
|
|
|
|
|
VLOG(3) << "start match var " << DebugString(var) << " of op "
|
|
|
|
|
<< op->Name();
|
|
|
|
|
VLOG(3) << pool_.ToString();
|
|
|
|
|
VLOG(3) << "matched in pool : "
|
|
|
|
|
<< ((cache == nullptr) ? "False" : "True");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (cache == nullptr) continue;
|
|
|
|
|
if (var->Name() == cache->Name()) {
|
|
|
|
|
VLOG(3) << "The same cache variable is cascade reused." << var->Name()
|
|
|
|
|
<< " is re-filled to the pool after"
|
|
|
|
|
<< "the reused op is finished. Current op can not "
|
|
|
|
|
<< "replace it again. Skip this candidate.";
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
|
|
int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
|
|
|
|
|
VLOG(3) << string::Sprintf(
|
|
|
|
|
"!!! %s, %s => %s, cache idx %d, pool size %d",
|
|
|
|
|
std::to_string(reuse_id++), DebugString(var), DebugString(cache),
|
|
|
|
|
node_idx_in_pool, static_cast<int>(pool_.size()));
|
|
|
|
|
|
|
|
|
|
// update CFG Graph on the fly.
|
|
|
|
|
// reused var maybe re-fill into the pool
|
|
|
|
|
cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx);
|
|
|
|
|
// NOTE(dzhwinter): we need to both update the ProgramDesc
|
|
|
|
|
// and IR Graph. because op_desc/var_desc is used in CreateOp,
|
|
|
|
|
// CreateVar when running happens. But IR Graph
|
|
|
|
|
// define the dependence relationship between nodes.
|
|
|
|
|
RenameVarInGraphDesc(var->Name(), cache->Name(), idx);
|
|
|
|
|
RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get());
|
|
|
|
|
if (cache != nullptr) {
|
|
|
|
|
if (var->Name() == cache->Name()) {
|
|
|
|
|
VLOG(3) << "The same cache variable is cascade reused."
|
|
|
|
|
<< var->Name() << " is re-filled to the pool after"
|
|
|
|
|
<< "the reused op is finished. Current op can not "
|
|
|
|
|
<< "replace it again. Skip this candidate.";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pool_.Erase(cache);
|
|
|
|
|
}
|
|
|
|
|
int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
|
|
|
|
|
VLOG(3) << string::Sprintf(
|
|
|
|
|
"!!! %s, %s => %s, cache idx %d, pool size %d",
|
|
|
|
|
std::to_string(reuse_id++), DebugString(var), DebugString(cache),
|
|
|
|
|
node_idx_in_pool, static_cast<int>(pool_.size()));
|
|
|
|
|
// NOTE(dzhwinter): update the ProgramDesc/IR Graph
|
|
|
|
|
// and the CFG Graph on the fly.
|
|
|
|
|
//
|
|
|
|
|
// IR Graph define the dependence relationship between nodes.
|
|
|
|
|
//
|
|
|
|
|
// ProgramDesc defines the input/output vars. Its used in
|
|
|
|
|
// CreateOp, CreateVar when running happens.
|
|
|
|
|
//
|
|
|
|
|
// CFG Graph store the liveness information, when reuse happens
|
|
|
|
|
// we also need to update the variable liveness.
|
|
|
|
|
cfg_->RenameVarInCFGGraph(var->Name(), cache->Name(), idx);
|
|
|
|
|
RenameVarInGraphDesc(var->Name(), cache->Name(), idx);
|
|
|
|
|
RenameVarInGraphNode(var->Name(), cache->Name(), idx, graph.get());
|
|
|
|
|
|
|
|
|
|
// fill the pool
|
|
|
|
|
std::unordered_set<std::string> unlived_vars;
|
|
|
|
|
for (auto var : cfg_->LiveIn(op)) {
|
|
|
|
|
if (cfg_->LiveOut(op).count(var) == 0) {
|
|
|
|
|
unlived_vars.emplace(var);
|
|
|
|
|
pool_.Erase(cache);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto var : unlived_vars) {
|
|
|
|
|
}
|
|
|
|
|
// fill the pool
|
|
|
|
|
for (auto var : cfg_->LiveIn(op)) {
|
|
|
|
|
if (cfg_->LiveOut(op).count(var) == 0) {
|
|
|
|
|
ir::Node* var_node = cfg_->GetNodeByName(var, op);
|
|
|
|
|
if (var_node == nullptr) continue;
|
|
|
|
|
if (NodeCanReused(var_node) && !pool_.Has(var_node)) {
|
|
|
|
|
pool_.Insert(var_node);
|
|
|
|
|
}
|
|
|
|
|