|
|
|
@ -337,6 +337,34 @@ ExtractInputAndOutputOfSubGraph(std::vector<Node *> &graph) { // NOLINT
|
|
|
|
|
std::vector<Node *>(outputs.begin(), outputs.end()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void FilterRedundantOutputOfSubGraph(DataFlowGraph *graph) {
|
|
|
|
|
std::vector<Node *> op_nodes;
|
|
|
|
|
for (auto &node : GraphTraits<DataFlowGraph>(graph).nodes_in_TS()) {
|
|
|
|
|
if (node.type() == Node::Type::kValue || node.deleted()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
op_nodes.push_back(&node);
|
|
|
|
|
}
|
|
|
|
|
size_t op_num = op_nodes.size();
|
|
|
|
|
for (size_t i = 0; i < op_num; i++) {
|
|
|
|
|
if (op_nodes[i]->type() == Node::Type::kFunction) continue;
|
|
|
|
|
std::unordered_set<std::string> follow_up_input_names;
|
|
|
|
|
for (size_t j = i + 1; j < op_num; j++) {
|
|
|
|
|
for (auto *in : op_nodes[j]->inlinks) {
|
|
|
|
|
follow_up_input_names.insert(in->name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<Node *> filtered_subgraph_outlinks;
|
|
|
|
|
for (auto *out : op_nodes[i]->outlinks) {
|
|
|
|
|
if (follow_up_input_names.count(out->name())) {
|
|
|
|
|
filtered_subgraph_outlinks.push_back(out);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_GE(filtered_subgraph_outlinks.size(), 1UL);
|
|
|
|
|
op_nodes[i]->outlinks = filtered_subgraph_outlinks;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace analysis
|
|
|
|
|
} // namespace inference
|
|
|
|
|
} // namespace paddle
|
|
|
|
|