|
|
@ -100,6 +100,7 @@ static inline ir::Node* GetNextCascadeInplacedVar(ir::Node* var) {
|
|
|
|
|
|
|
|
|
|
|
|
static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
|
|
|
|
static inline ir::Node* GetPrevCascadeInplacedVar(ir::Node* var) {
|
|
|
|
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
|
|
|
|
PADDLE_ENFORCE(var && var->IsVar() && !var->IsCtrlVar());
|
|
|
|
|
|
|
|
if (var->inputs.empty()) return nullptr;
|
|
|
|
auto* prev_op = var->inputs.at(0);
|
|
|
|
auto* prev_op = var->inputs.at(0);
|
|
|
|
auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(),
|
|
|
|
auto input_it = std::find_if(prev_op->inputs.begin(), prev_op->inputs.end(),
|
|
|
|
[&](ir::Node* node) {
|
|
|
|
[&](ir::Node* node) {
|
|
|
@ -165,12 +166,6 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
|
|
|
|
view_.Build(graph.get());
|
|
|
|
view_.Build(graph.get());
|
|
|
|
InitSSAGraphNodes();
|
|
|
|
InitSSAGraphNodes();
|
|
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
|
|
|
|
|
|
|
|
constexpr char graph_path1[] = "ir_graph_before_inplaced.txt";
|
|
|
|
|
|
|
|
std::unique_ptr<std::ostream> fout1(new std::ofstream(graph_path1));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(fout1->good());
|
|
|
|
|
|
|
|
printer->Print(*graph, *fout1);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto* op : view_.AllOps()) {
|
|
|
|
for (auto* op : view_.AllOps()) {
|
|
|
|
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
|
|
|
|
if (FLAGS_enable_inplace_whitelist && !whitelist_.count(op->Name()))
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
@ -178,10 +173,6 @@ std::unique_ptr<ir::Graph> InplacePass::ApplyImpl(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
graph->ResolveHazard(var_nodes_);
|
|
|
|
graph->ResolveHazard(var_nodes_);
|
|
|
|
|
|
|
|
|
|
|
|
constexpr char graph_path[] = "ir_graph_inplaced.txt";
|
|
|
|
|
|
|
|
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(fout->good());
|
|
|
|
|
|
|
|
printer->Print(*graph, *fout);
|
|
|
|
|
|
|
|
return graph;
|
|
|
|
return graph;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -291,6 +282,7 @@ void InplacePass::WithdrawModify(const SSANodePair& nodes,
|
|
|
|
|
|
|
|
|
|
|
|
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
|
|
|
|
void InplacePass::TryInplaceOpInputOutput(ir::Node* op,
|
|
|
|
ir::Graph* graph) const {
|
|
|
|
ir::Graph* graph) const {
|
|
|
|
|
|
|
|
VLOG(4) << "Try to inplace op " << op->Name();
|
|
|
|
PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
|
|
|
|
PADDLE_ENFORCE(op->Op() != nullptr && op->Op()->Block() != nullptr,
|
|
|
|
"op_desc is nullptr");
|
|
|
|
"op_desc is nullptr");
|
|
|
|
// 4 pre-requirments need to meet if the op want to inplaced.
|
|
|
|
// 4 pre-requirments need to meet if the op want to inplaced.
|
|
|
|