|
|
|
@ -46,7 +46,6 @@ namespace details {
|
|
|
|
|
std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
|
|
|
|
|
std::unique_ptr<ir::Graph> graph) const {
|
|
|
|
|
auto nodes = graph->Nodes();
|
|
|
|
|
ClearControlDepVars(graph.get());
|
|
|
|
|
CollectSkipVarsSet(nodes);
|
|
|
|
|
|
|
|
|
|
cfg_.reset(new details::ControlFlowGraph(*graph));
|
|
|
|
@ -79,7 +78,7 @@ std::unique_ptr<ir::Graph> MemoryOptimizePass::ApplyImpl(
|
|
|
|
|
ir::Node* cache = pool_.FindBestFitNode(var);
|
|
|
|
|
while (cache != nullptr && var->Name() == cache->Name()) {
|
|
|
|
|
VLOG(3) << "The same cache variable is cascade reused. "
|
|
|
|
|
<< var->Name() << " is re-filled to the pool after"
|
|
|
|
|
<< cache->Name() << " is re-filled to the pool after "
|
|
|
|
|
<< "the reused op is finished. Current op can not "
|
|
|
|
|
<< "replace it again. Skip this candidate.";
|
|
|
|
|
cache = pool_.FindNextBestFitNode(var, cache);
|
|
|
|
@ -325,32 +324,6 @@ void MemoryOptimizePass::RenameVarInGraphNode(const std::string& var,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MemoryOptimizePass::ClearControlDepVars(ir::Graph* graph) const {
|
|
|
|
|
for (auto& op : graph->Nodes()) {
|
|
|
|
|
if (!op->IsOp()) continue;
|
|
|
|
|
{
|
|
|
|
|
auto& nodes = op->inputs;
|
|
|
|
|
nodes.erase(
|
|
|
|
|
std::remove_if(nodes.begin(), nodes.end(),
|
|
|
|
|
[&](ir::Node* var) { return var->IsCtrlVar(); }),
|
|
|
|
|
nodes.end());
|
|
|
|
|
}
|
|
|
|
|
{
|
|
|
|
|
auto& nodes = op->outputs;
|
|
|
|
|
nodes.erase(
|
|
|
|
|
std::remove_if(nodes.begin(), nodes.end(),
|
|
|
|
|
[&](ir::Node* var) { return var->IsCtrlVar(); }),
|
|
|
|
|
nodes.end());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto& node : graph->Nodes()) {
|
|
|
|
|
if (node->IsCtrlVar()) {
|
|
|
|
|
graph->RemoveNode(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace details
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|