@@ -31,15 +31,6 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"

DEFINE_bool(enable_subgraph_optimize, false,
            "SubGraph also reuses global graph variables; this reduces "
            "memory occupation "
            "but carries a higher risk of memory reuse errors. Disabled by "
            "default.");
DEFINE_string(memory_optimize_debug, "",
              "Debug the output variable of an operator when doing variable "
              "reuse in the memory reuse pass. "
              "Only for debugging; disabled by default.");

namespace paddle {
namespace framework {
namespace ir {
@@ -57,15 +48,6 @@ void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
    auto* op_desc = op->Op();
    // Some ops in the graph have no op desc.
    if (op_desc == nullptr) continue;
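    // Ops that own a sub-block (e.g. while / conditional_block and their
    // grad ops) are only optimized when FLAGS_enable_subgraph_optimize is
    // on; otherwise they are skipped.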
    if (OpHasSubBlock(op_desc)) {
      if (FLAGS_enable_subgraph_optimize) {
        SubGraphOptimize(op_desc);
      } else {
        VLOG(3) << op->Name()
                << " has a subblock, but subgraph optimize is disabled; skipped.";
        continue;
      }
    }

    for (auto& var : op->outputs) {
      if (var->IsVar() && !var->IsCtrlVar() && skip_set_.count(var->Name())) {
@@ -82,13 +64,6 @@ void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
                << "replace it again. Skip this candidate.";
        cache = pool_.FindNextBestFitNode(var, cache);
      }
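      // When --memory_optimize_debug names this variable, dump the pool
      // state and whether a match was found.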
      if (var->Name() == FLAGS_memory_optimize_debug) {
        VLOG(3) << "start match var " << DebugString(var) << " of op "
                << op->Name();
        VLOG(3) << pool_.ToString();
        VLOG(3) << "matched in pool : "
                << ((cache == nullptr) ? "False" : "True");
      }

      if (cache != nullptr) {
        int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
@@ -128,81 +103,6 @@ void MemoryOptimizePass::ApplyImpl(ir::Graph* graph) const {
  graph->ResolveHazard(var_nodes_);
}

void MemoryOptimizePass::SubGraphOptimize(OpDesc* op_desc) const {
  // Applies to conditional_block, while op, and their grad ops.
  auto* sub_block_desc =
      AttrReader(op_desc->GetAttrMap()).Get<BlockDesc*>("sub_block");

  // Create a mirror block to construct an IR Graph.
  ProgramDesc prog;
  auto* copy_block = prog.MutableBlock(0);
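  // Mirror every op of the sub-block into the copy; Flush() syncs the
  // copied op's cached members into its underlying proto.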
  for (auto* op : sub_block_desc->AllOps()) {
    auto* copy_op = copy_block->AppendOp();
    copy_op->CopyFrom(*op);
    copy_op->Flush();
  }

  for (auto* var : sub_block_desc->AllVars()) {
    auto* copy_var = copy_block->Var(var->Name());
    copy_var->SetDataType(var->GetDataType());
    // Only LoD tensors can be reused, so ignore the multiple-dims case.
    copy_var->SetType(var->GetType());
    copy_var->SetShape(var->GetShape());
    copy_var->SetPersistable(var->Persistable());
  }
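
  // Build a temporary IR graph over the mirrored program so the reuse
  // analysis below can work on graph nodes.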
  ir::Graph sub_graph(prog);
  std::unordered_set<ir::Node*> sub_graph_all_ops;
  FilterVariables(sub_graph.Nodes(), [&](ir::Node* var) {
    if (var->IsVar() && !var->IsCtrlVar()) {
      sub_graph_all_ops.emplace(var);
    }
  });
  int sub_reuse_id = 0;
  // Subgraph nodes are unordered; reuse must follow the desc order,
  // so find the right op node through the descs.
  for (auto* sub_op_desc : sub_block_desc->AllOps()) {
    ir::Node* sub_op = nullptr;
    for (auto* node : sub_graph_all_ops) {
      if (node->Op() == sub_op_desc) {
        sub_op = node;
        break;
      }
    }
    PADDLE_ENFORCE(sub_op != nullptr);
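    // For each reusable output of this op, look for a best-fit node in the
    // pool to take over its memory.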
    for (auto* var : sub_op->outputs) {
      if (NodeCanReused(var)) {
        ir::Node* cache = pool_.FindBestFitNode(var);
        if (cache != nullptr) {
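          // Reusing a pooled variable of a different data type would be
          // unsafe; skip such candidates.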
          if (var->Var()->GetDataType() != cache->Var()->GetDataType()) {
            continue;
          }
          int node_idx_in_pool = pool_.GetNodeIndexInPool(cache);
          VLOG(3) << string::Sprintf(
              "!!! %s, %s => %s, cache idx %d, pool size %d",
              std::to_string(sub_reuse_id++), DebugString(var),
              DebugString(cache), node_idx_in_pool,
              static_cast<int>(pool_.size()));
          // NOTE(dzh): The subblock is not in the IR graph, so modify the
          // block_desc immediately to make the subblock variable reuse
          // strategy take effect; because it is a single op in the graph,
          // there is no need to update the IR nodes.
          // FIXME(liuwei1031): Graph is not aware of the existence of
          // BlockDescs and ProgramDescs;
          // operations related to BlockDesc or ProgramDesc should be
          // performed on Graph or Node directly!
          sub_op_desc->Rename(var->Name(), cache->Name());
          if (sub_op_desc->Block() != nullptr &&
              sub_op_desc->Block()->HasVar(var->Name())) {
            sub_op_desc->Block()->RemoveVar(var->Name());
          }
        }
      }
    }
  }
}

void MemoryOptimizePass::CollectSkipVarsSet(ir::Graph* graph) const {
  // Fill skip_set_.
  PADDLE_ENFORCE(graph->Has(kMemOptSkipVars));