From e823ce68bbf7a0f86e2dc45120efa8cc155390d4 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Wed, 8 Aug 2018 13:26:19 +0000 Subject: [PATCH] filter redundant output --- .../inference/analysis/data_flow_graph.cc | 28 +++++++++++++++++++ .../inference/analysis/data_flow_graph.h | 1 + .../analysis/data_flow_graph_to_fluid_pass.cc | 1 + 3 files changed, 30 insertions(+) diff --git a/paddle/fluid/inference/analysis/data_flow_graph.cc b/paddle/fluid/inference/analysis/data_flow_graph.cc index 8a3af0a8eb..7f64bc75ae 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph.cc @@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector &graph) { // NOLINT std::vector(outputs.begin(), outputs.end())); } +void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) { + std::vector op_nodes; + for (auto &node : GraphTraits(graph).nodes_in_TS()) { + if (node.type() == Node::Type::kValue || node.deleted()) { + continue; + } + op_nodes.push_back(&node); + } + size_t op_num = op_nodes.size(); + for (size_t i = 0; i < op_num; i++) { + if (op_nodes[i]->type() == Node::Type::kFunction) continue; + std::unordered_set follow_up_input_names; + for (size_t j = i + 1; j < op_num; j++) { + for (auto *in : op_nodes[j]->inlinks) { + follow_up_input_names.insert(in->name()); + } + } + std::vector filtered_subgraph_outlinks; + for (auto *out : op_nodes[i]->outlinks) { + if (follow_up_input_names.count(out->name())) { + filtered_subgraph_outlinks.push_back(out); + } + } + PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL); + op_nodes[i]->outlinks = filtered_subgraph_outlinks; + } +} + } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph.h b/paddle/fluid/inference/analysis/data_flow_graph.h index 16aeae4d35..bb3ec6bbc1 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph.h +++ b/paddle/fluid/inference/analysis/data_flow_graph.h @@ -178,6 +178,7 @@ struct GraphTraits { std::pair, std::vector> ExtractInputAndOutputOfSubGraph(std::vector &graph); // NOLINT +void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph); } // namespace analysis } // namespace inference } // namespace paddle diff --git a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc index 2328d87042..7365c826a8 100644 --- a/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc +++ b/paddle/fluid/inference/analysis/data_flow_graph_to_fluid_pass.cc @@ -52,6 +52,7 @@ bool DataFlowGraphToFluidPass::Initialize(Argument *argument) { bool DataFlowGraphToFluidPass::Finalize() { return true; } void DataFlowGraphToFluidPass::Run(DataFlowGraph *graph) { + FilterRedundantOutputOfSubGraph(graph); LOG(INFO) << "graph.inputs " << graph->inputs.size(); for (auto &node : GraphTraits(graph).nodes_in_TS()) { if (node.deleted()) continue;