fea/infer memory optim2 (#14953)

inference-pre-release-gpu
Yan Chunwei 7 years ago committed by GitHub
parent 6597ccb01f
commit 885c4e57ab

@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/fc_fuse_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {

@ -18,8 +18,10 @@ limitations under the License. */
#include <fstream>
#include <iosfwd>
#include <ostream>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph_traits.h"
DEFINE_string(print_sub_graph_dir, "",
"FLAGS_print_sub_graph_dir is used "
@ -41,7 +43,7 @@ void SortHelper(
}
}
VLOG(3) << "topology sort insert: " << node->Name()
VLOG(5) << "topology sort insert: " << node->Name() << " "
<< reinterpret_cast<void *>(node) << " input " << node->inputs.size();
ret->push_back(node);
}
@ -99,12 +101,13 @@ std::vector<ir::Node *> TopologySortOperations(const Graph &graph) {
return ret;
}
// Build operator inlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
for (auto &n : graph.Nodes()) {
if (n->NodeType() != ir::Node::Type::kOperation) continue;
if (!n->IsOp()) continue;
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::unordered_set<ir::Node *>();
}
@ -121,6 +124,119 @@ std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
return adj_list;
}
// Build operator outlink edge table.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationOutAdjList(
const Graph &graph) {
std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
for (auto &n : graph.Nodes()) {
if (!n->IsOp()) continue;
if (adj_list.find(n) == adj_list.end()) {
adj_list[n] = std::unordered_set<ir::Node *>();
}
for (auto &var : n->outputs) {
for (auto &adj_n : var->outputs) {
PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
VLOG(40) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
<< " -> " << n->Name() << reinterpret_cast<void *>(n)
<< " via " << var->Name() << reinterpret_cast<void *>(var);
adj_list[n].insert(adj_n);
}
}
}
return adj_list;
}
std::vector<ir::Node *> OpDFSSort(const Graph &graph) {
auto edge_table = BuildOperationOutAdjList(graph);
std::stack<Node *> stack;
for (auto &ele : edge_table) {
if (ele.first->inputs.empty()) {
// find the input ops (those without input vars)
stack.push(ele.first);
} else {
// find the ops with only persistable vars as inputs.
bool all_persistable = true;
for (auto *input : ele.first->inputs) {
if (!(input->IsVar() && input->Var() && input->Var()->Persistable())) {
all_persistable = false;
}
}
if (all_persistable) {
stack.push(ele.first);
}
}
}
std::vector<Node *> res;
// start from the feed op and DFS
std::unordered_set<Node *> unique_set;
while (!stack.empty()) {
// will start from the last feed by default.
auto cur = stack.top();
stack.pop();
unique_set.insert(cur);
res.push_back(cur);
for (auto *op : edge_table[cur]) {
if (!unique_set.count(op)) {
stack.push(op);
}
}
}
return res;
}
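
Stripped of Paddle's Graph types, the idea in OpDFSSort is to seed a stack with the source operators (no inputs, or only persistable inputs) and then emit nodes depth-first along the out-edges. A minimal standalone sketch of that ordering, using hypothetical node names and a plain adjacency map in place of BuildOperationOutAdjList:

#include <iostream>
#include <stack>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // Hypothetical operator chain: feed -> fc -> relu -> fetch.
  std::unordered_map<std::string, std::vector<std::string>> out_edges = {
      {"feed", {"fc"}}, {"fc", {"relu"}}, {"relu", {"fetch"}}, {"fetch", {}}};

  // Count in-degrees so the source operators can be found.
  std::unordered_map<std::string, int> in_degree;
  for (auto &kv : out_edges) in_degree[kv.first] = 0;
  for (auto &kv : out_edges)
    for (auto &dst : kv.second) ++in_degree[dst];

  std::stack<std::string> stack;
  for (auto &kv : in_degree)
    if (kv.second == 0) stack.push(kv.first);  // seed with the input ops

  // DFS over out-edges; the visited set keeps each op unique in the result.
  std::vector<std::string> order;
  std::unordered_set<std::string> visited;
  while (!stack.empty()) {
    auto cur = stack.top();
    stack.pop();
    if (!visited.insert(cur).second) continue;
    order.push_back(cur);
    for (auto &next : out_edges[cur])
      if (!visited.count(next)) stack.push(next);
  }
  for (auto &name : order) std::cout << name << "\n";  // feed fc relu fetch
}

Because this traversal is a plain DFS, the result alone does not guarantee a topological order; TopologyDfsSortOperations below adds the in-degree bookkeeping that enforces the dependency order.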
std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph) {
std::vector<ir::Node *> nodes;
std::unordered_map<Node *, int> in_degree;
auto set_out_ops_ready = [&](Node *var) {
for (auto *op : var->outputs) {
--in_degree[op];
}
};
// build in_degree
for (auto *node : graph.Nodes()) {
if (node->IsOp()) {
in_degree[node] += node->inputs.size();
} else if (node->IsVar() && node->inputs.empty()) {
// mark all the inputs of the whole graph as ready.
set_out_ops_ready(node);
}
}
std::deque<Node *> op_queue;
// first visit
for (auto &node : OpDFSSort(graph)) {
if (node->IsOp()) {
op_queue.push_back(node);
}
}
// traverse the graph
int num_ops = op_queue.size();
while (num_ops) {
for (auto it = op_queue.begin(); it != op_queue.end(); it++) {
auto *&cur_op = *it;
if (!cur_op || in_degree[cur_op] > 0) continue;
// visit this node
// mark all the output vars of this op as ready.
for (auto *out_var : cur_op->outputs) {
if (!out_var) continue;
set_out_ops_ready(out_var);
}
VLOG(8) << "visit " << cur_op->Name();
nodes.push_back(cur_op);
cur_op = nullptr;
num_ops--;
}
}
return nodes;
}
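
TopologyDfsSortOperations uses the DFS order only as a visiting preference: an operator is actually emitted once its in-degree (one unit per input) has dropped to zero, and emitting it decrements the in-degree of every consumer reached through its output variables. A compact sketch of that bookkeeping on a hypothetical diamond-shaped graph (a feeds b and c, both feed d):

#include <deque>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical diamond graph and its per-operator in-degrees.
  std::unordered_map<std::string, std::vector<std::string>> out_edges = {
      {"a", {"b", "c"}}, {"b", {"d"}}, {"c", {"d"}}, {"d", {}}};
  std::unordered_map<std::string, int> in_degree = {
      {"a", 0}, {"b", 1}, {"c", 1}, {"d", 2}};

  // Candidate order, e.g. produced by a DFS pass such as the sketch above.
  std::deque<std::string> candidates = {"a", "b", "d", "c"};

  std::vector<std::string> order;
  size_t remaining = candidates.size();
  while (remaining > 0) {
    for (auto &op : candidates) {
      if (op.empty() || in_degree[op] > 0) continue;  // not ready yet
      for (auto &consumer : out_edges[op]) --in_degree[consumer];
      order.push_back(op);
      op.clear();  // mark as emitted, mirroring cur_op = nullptr above
      --remaining;
    }
  }
  for (auto &name : order) std::cout << name << "\n";  // a b c d
}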
size_t GraphNum(const Graph &graph) {
std::unordered_set<ir::Node *> nodes(graph.Nodes());
std::unordered_set<ir::Node *> visited_nodes;
@ -203,6 +319,29 @@ size_t GraphNum(const Graph &graph) {
return graph_count;
}
void CleanIndividualNodes(Graph *graph) {
std::unordered_set<Node *> nodes2rm;
for (auto *node : graph->Nodes()) {
if (node->inputs.empty() && node->outputs.empty()) {
nodes2rm.insert(node);
}
}
for (auto *node : nodes2rm) {
graph->RemoveNode(node);
}
}
std::vector<Node *> TopologyVarientSort(const Graph &graph,
SortKind sort_kind) {
switch (sort_kind) {
case SortKind::TS:
return framework::ir::TopologySortOperations(graph);
default:
return framework::ir::TopologyDfsSortOperations(graph);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle

@ -34,6 +34,23 @@ size_t GraphNum(const Graph &graph);
// `graph` cannot contain a cycle.
std::vector<ir::Node *> TopologySortOperations(const Graph &graph);
// Topological sort, but try to DFS.
std::vector<ir::Node *> TopologyDfsSortOperations(const Graph &graph);
// Different kinds to sort the operators in a graph to a sequence.
enum class SortKind {
// Topological Search
TS = 0,
// Topological and Depth First Search
TDFS
};
// Several kinds of topological sort.
std::vector<Node *> TopologyVarientSort(const Graph &graph, SortKind sort_kind);
// Clean the nodes that don't connect to others.
void CleanIndividualNodes(Graph *graph);
// Build an adjacency list of operations for the `graph`.
std::map<ir::Node *, std::unordered_set<ir::Node *>> BuildOperationAdjList(
const Graph &graph);
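
A possible call site for the new sorting entry point, assuming a constructed ir::Graph is available, might look like the sketch below; the DumpOpOrder helper is hypothetical.

#include <iostream>
#include "paddle/fluid/framework/ir/graph_helper.h"

void DumpOpOrder(const paddle::framework::ir::Graph &graph) {
  using paddle::framework::ir::SortKind;
  using paddle::framework::ir::TopologyVarientSort;
  // TDFS yields the depth-first topological order that the inference memory
  // optimization relies on; TS selects the original TopologySortOperations.
  auto nodes = TopologyVarientSort(graph, SortKind::TDFS);
  for (auto *node : nodes) {
    std::cout << node->Name() << "\n";
  }
}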

@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
@ -29,6 +28,14 @@ namespace ir {
std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
std::unique_ptr<Graph> graph) const {
// Remove the unneeded variables after memory optimization.
std::unordered_set<std::string> vars2remove;
if (graph->Has(kGraphToProgramVarsToRemove)) {
vars2remove = graph->Get<std::unordered_set<std::string>>(
kGraphToProgramVarsToRemove);
VLOG(2) << "graph to program remove " << vars2remove.size() << " nodes";
}
ProgramDesc& program = Get<ProgramDesc>("program");
std::unique_ptr<proto::ProgramDesc> program_pb(
@ -40,25 +47,35 @@ std::unique_ptr<Graph> GraphToProgramPass::ApplyImpl(
std::unordered_set<std::string> visited_vars;
for (ir::Node* n : graph->Nodes()) {
if (n->IsVar()) {
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0) {
if (n->Var() && visited_vars.count(n->Var()->Name()) == 0 &&
!vars2remove.count(n->Var()->Name())) {
visited_vars.insert(n->Var()->Name());
block->add_vars()->MergeFrom(*n->Var()->Proto());
}
}
}
block->clear_ops();
std::vector<ir::Node*> nodes = TopologySortOperations(*graph);
std::vector<ir::Node*> nodes;
if (Has(kGraphToProgramSortKind)) {
// Inference memory optimization relies on this branch.
int sort_kind = Get<int>(kGraphToProgramSortKind);
nodes = TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(sort_kind));
} else {
nodes = TopologySortOperations(*graph);
}
for (ir::Node* n : nodes) {
if (!n->Op()) {
continue;
}
if (!n->Op()) continue;
block->add_ops()->MergeFrom(*n->Op()->Proto());
}
program.CopyFrom(*program_pb);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
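
Both hooks are driven from outside the pass: a prior pass (here, the inference memory optimization) is expected to stash the removable variable names on the graph under kGraphToProgramVarsToRemove, and the caller selects the operator ordering through a pass attribute. A hedged sketch of the wiring, with a hypothetical helper and an illustrative variable name:

#include <memory>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

namespace fw = paddle::framework;

std::unique_ptr<fw::ir::Graph> ConvertGraphToProgram(
    std::unique_ptr<fw::ir::Graph> graph, fw::ProgramDesc *program) {
  // Variables the memory optimization decided to drop; "fc_0.tmp_0" is
  // purely illustrative.
  graph->Set(fw::ir::kGraphToProgramVarsToRemove,
             new std::unordered_set<std::string>({"fc_0.tmp_0"}));

  auto pass = fw::ir::PassRegistry::Instance().Get("graph_to_program_pass");
  // 1 == SortKind::TDFS; without this attribute the pass falls back to
  // TopologySortOperations.
  pass->Set(fw::ir::kGraphToProgramSortKind, new int(1));
  pass->SetNotOwned("program", program);
  return pass->Apply(std::move(graph));
}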

@ -20,6 +20,10 @@ namespace paddle {
namespace framework {
namespace ir {
const char kGraphToProgramVarsToRemove[] =
"__graph_to_program_vars_to_remove__";
const char kGraphToProgramSortKind[] = "__graph_to_program_sort_kind__";
class GraphToProgramPass : public Pass {
protected:
std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const override;

@ -135,4 +135,4 @@ GraphVizPass::marked_nodes_t GraphVizPass::ConsumeMarkedNodes(
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);

@ -64,7 +64,7 @@ class Node {
std::string Name() const { return name_; }
VarDesc* Var() {
VarDesc* Var() const {
PADDLE_ENFORCE(IsVar());
return var_desc_.get();
}

@ -50,8 +50,8 @@ void NaiveExecutor::Run() {
"running Paddle Inference";
#endif // PADDLE_ON_INFERENCE
for (auto &op : ops_) {
VLOG(3) << std::this_thread::get_id() << " run " << op->Type()
<< " on scope " << scope_;
VLOG(4) << std::this_thread::get_id() << " run "
<< op->DebugStringEx(scope_) << " on scope " << scope_;
op->SetIsCalledByExecutor(false);
op->Run(*scope_, place_);
}
@ -69,10 +69,12 @@ void NaiveExecutor::CreateVariables(const ProgramDesc &desc, int block_id,
anc = anc->parent();
}
int num_vars = 0;
for (auto &var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
num_vars++;
if (persistable == var->Persistable()) {
if (persistable) {
@ -90,6 +92,7 @@ void NaiveExecutor::CreateVariables(const ProgramDesc &desc, int block_id,
}
}
}
VLOG(4) << "naive executor create " << num_vars << " vars";
}
void NaiveExecutor::CreateOps(const ProgramDesc &desc, int block_id,

@ -18,6 +18,7 @@ cc_library(analysis SRCS
analyzer.cc
analysis_pass
DEPS ${analysis_deps} analysis_helper
${INFER_IR_PASSES}
)
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)

@ -15,8 +15,8 @@
#include "paddle/fluid/inference/analysis/analyzer.h"
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h"
#include "paddle/fluid/inference/analysis/passes/passes.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace inference {
@ -24,13 +24,16 @@ namespace analysis {
Analyzer::Analyzer() {}
void Analyzer::Run(Argument *argument) { RunIrAnalysis(argument); }
void Analyzer::Run(Argument *argument) { RunAnalysis(argument); }
void Analyzer::RunIrAnalysis(Argument *argument) {
std::vector<std::string> passes({"ir_analysis_compose_pass"});
for (auto &pass : passes) {
PassRegistry::Global().Retreive(pass)->Run(argument);
void Analyzer::RunAnalysis(Argument *argument) {
PADDLE_ENFORCE(argument->analysis_passes_valid(),
"analsis_passes is not valid in the argument.");
for (auto &pass : argument->analysis_passes()) {
string::PrettyLogH1("--- Running analysis [%s]", pass);
auto *ptr = PassRegistry::Global().Retreive(pass);
PADDLE_ENFORCE_NOT_NULL(ptr, "no analysis pass called %s", pass);
ptr->Run(argument);
}
}

@ -54,7 +54,7 @@ class Analyzer final {
DISABLE_COPY_AND_ASSIGN(Analyzer);
protected:
void RunIrAnalysis(Argument* argument);
void RunAnalysis(Argument* argument);
};
} // namespace analysis

@ -32,6 +32,8 @@ TEST(Analyzer, analysis_without_tensorrt) {
argument.SetModelDir(FLAGS_inference_model_dir);
argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
argument.SetUseGPU(false);
argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass",
"ir_params_sync_among_devices_pass"});
Analyzer analyser;
analyser.Run(&argument);
@ -44,6 +46,8 @@ TEST(Analyzer, analysis_with_tensorrt) {
argument.SetModelDir(FLAGS_inference_model_dir);
argument.SetIrAnalysisPasses({"infer_clean_graph_pass"});
argument.SetUseGPU(false);
argument.SetAnalysisPasses({"ir_graph_build_pass", "ir_analysis_pass",
"ir_params_sync_among_devices_pass"});
Analyzer analyser;
analyser.Run(&argument);

@ -110,16 +110,20 @@ struct Argument {
// The overall Scope to work on.
DECL_ARGUMENT_UNIQUE_FIELD(scope, Scope, framework::Scope);
// The default program, loaded from disk.
DECL_ARGUMENT_UNIQUE_FIELD(main_program, MainProgram, framework::ProgramDesc);
// The ir passes to perform in analysis phase.
DECL_ARGUMENT_FIELD(ir_analysis_passes, IrAnalysisPasses,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(analysis_passes, AnalysisPasses,
std::vector<std::string>);
// Pass a set of op types to enable its mkldnn kernel
DECL_ARGUMENT_FIELD(mkldnn_enabled_op_types, MKLDNNEnabledOpTypes,
std::unordered_set<std::string>);
// Passed from config.
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
@ -127,6 +131,13 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_workspace_size, TensorRtWorkspaceSize, int);
DECL_ARGUMENT_FIELD(tensorrt_min_subgraph_size, TensorRtMinSubgraphSize, int);
// Memory optimization related.
DECL_ARGUMENT_FIELD(enable_memory_optim, EnableMemoryOptim, bool);
DECL_ARGUMENT_FIELD(memory_optim_force_update, MemoryOptimForceUpdate, bool);
// Indicate which kind of sort algorithm is used for operators; the memory
// optimization relies on the sort algorithm.
DECL_ARGUMENT_FIELD(memory_optim_sort_kind, MemoryOptimSortKind, int);
// The program transformed by IR analysis phase.
DECL_ARGUMENT_UNIQUE_FIELD(ir_analyzed_program, IrAnalyzedProgram,
framework::proto::ProgramDesc);
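
The DECL_ARGUMENT_FIELD macro generates Set*, *() and *_valid() accessors, so the new memory-optimization knobs are flipped the same way the analyzer tests set the pass lists. A small sketch, with illustrative field values:

#include "paddle/fluid/inference/analysis/argument.h"

void ConfigureMemoryOptim(paddle::inference::analysis::Argument *argument) {
  argument->SetEnableMemoryOptim(true);        // turn the optimization on
  argument->SetMemoryOptimForceUpdate(false);  // do not force a plan rebuild
  // Assumed here to map to framework::ir::SortKind::TDFS (== 1), the
  // operator ordering the memory optimization relies on.
  argument->SetMemoryOptimSortKind(1);
}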

@ -28,6 +28,13 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/port.h"
#ifdef _WIN32
#define GCC_ATTRIBUTE(attr__) ;
#else
#define GCC_ATTRIBUTE(attr__) __attribute__((attr__));
#endif
#define __SHOULD_USE_RESULT__ GCC_ATTRIBUTE(warn_unused_result)
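
On MSVC the macro expands to a bare semicolon, elsewhere to __attribute__((warn_unused_result)) plus the semicolon, so an annotated declaration omits its own terminator. A hypothetical declaration, shown only to illustrate the macro:

#include <string>

// Hypothetical function; on GCC/Clang the compiler warns if a caller
// discards the returned status.
bool LoadInferenceModel(const std::string &path) __SHOULD_USE_RESULT__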
namespace paddle {
namespace inference {
namespace analysis {

@ -83,6 +83,7 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
PADDLE_ENFORCE(graph.get());
// Apply all the passes
for (const auto &pass : passes_) {
if (pass->Type() == "graph_viz_pass") continue;
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
graph = pass->Apply(std::move(graph));
}

@ -1,6 +1,6 @@
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS proto_desc)
if (TENSORRT_FOUND)
if (WITH_GPU AND TENSORRT_FOUND)
cc_library(tensorrt_subgraph_pass SRCS tensorrt_subgraph_pass.cc DEPS subgraph_detector tensorrt_op_teller)
set(analysis_deps ${analysis_deps}

@ -413,7 +413,6 @@ void SubGraphFuser::ReplaceNodesWithSubGraphs() {
auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)();
for (auto &subgraph : subgraphs) {
if (subgraph.size() <= (size_t)min_subgraph_size_) continue;
LOG(INFO) << "detect a subgraph size " << subgraph.size();
std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
// replace this sub-graph with the first node. Two steps: 1. Create a Block
// Node that contains this subgraph 2. Mark the nodes inside the sub-graph

@ -21,6 +21,7 @@
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace inference {
@ -77,6 +78,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
framework::BlockDesc block_desc(nullptr, &block_proto);
block_desc.Proto()->set_parent_idx(-1);
block_desc.Proto()->set_idx(0);
string::PrettyLogDetail("--- detect a sub-graph with %d nodes",
subgraph.size());
for (auto *node : subgraph) {
auto *op = block_desc.AppendOp();
*op->Proto() = *node->Op()->Proto();

@ -1,11 +1,18 @@
cc_library(ir_graph_build_pass SRCS ir_graph_build_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(ir_analysis_pass SRCS ir_analysis_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(memory_optim_pass SRCS memory_optimize_pass.cc DEPS analysis_pass)
cc_library(ir_params_sync_among_devices_pass SRCS ir_params_sync_among_devices_pass.cc DEPS analysis_pass argument ir_pass_manager)
cc_library(analysis_passes SRCS passes.cc DEPS ir_graph_build_pass ir_analysis_pass ir_params_sync_among_devices_pass)
cc_library(ir_graph_to_program_pass SRCS ir_graph_to_program_pass.cc DEPS analysis_pass graph_to_program_pass)
cc_library(analysis_passes SRCS passes.cc DEPS
ir_graph_build_pass
ir_analysis_pass
ir_params_sync_among_devices_pass
memory_optim_pass
ir_graph_to_program_pass
)
set(analysis_deps ${analysis_deps}
ir_graph_build_pass
ir_analysis_pass
analysis_passes
subgraph_detector
CACHE INTERNAL "")

@ -1,62 +0,0 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_analysis_compose_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace inference {
namespace analysis {
void IrAnalysisComposePass::RunImpl(Argument *argument) {
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
ApplyIrPasses(argument);
CollectFusionStatis(argument);
}
std::string IrAnalysisComposePass::repr() const {
return "ir-analysis-compose-pass";
}
void IrAnalysisComposePass::ApplyIrPasses(Argument *argument) {
std::vector<std::string> passes({
"ir_graph_build_pass", "ir_analysis_pass",
"ir_params_sync_among_devices_pass",
});
for (const auto &pass : passes) {
VLOG(2) << "Run pass " << pass;
auto *the_pass = PassRegistry::Global().Retreive(pass);
the_pass->Run(argument);
}
}
void IrAnalysisComposePass::CollectFusionStatis(Argument *argument) {
if (!argument->main_graph().Has(framework::ir::kFuseStatisAttr)) {
LOG(INFO) << "argument has no fuse statis";
return;
}
argument->SetFusionStatis(
argument->main_graph().Get<Argument::fusion_statis_t>(
framework::ir::kFuseStatisAttr));
}
} // namespace analysis
} // namespace inference
} // namespace paddle

@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_analysis_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
namespace paddle {
@ -31,9 +32,18 @@ void IrAnalysisPass::RunImpl(Argument* argument) {
IRPassManager the_ir_manager(argument);
graph = the_ir_manager.Apply(std::move(graph));
PADDLE_ENFORCE_GT(graph->Nodes().size(), 0);
argument->SetIrAnalyzedProgram(new framework::proto::ProgramDesc(
the_ir_manager.AcquireProgram(&graph, argument->main_program())));
argument->SetMainGraph(graph.release());
CollectFusionStatis(argument);
}
void IrAnalysisPass::CollectFusionStatis(Argument* argument) {
if (!argument->main_graph().Has(framework::ir::kFuseStatisAttr)) {
LOG(INFO) << "argument has no fuse statis";
return;
}
argument->SetFusionStatis(
argument->main_graph().Get<Argument::fusion_statis_t>(
framework::ir::kFuseStatisAttr));
}
std::string IrAnalysisPass::repr() const { return "ir-analysis-pass"; }

@ -29,6 +29,9 @@ namespace analysis {
class IrAnalysisPass : public AnalysisPass {
public:
void RunImpl(Argument* argument) override;
void CollectFusionStatis(Argument* argument);
std::string repr() const override;
};

@ -0,0 +1,45 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace inference {
namespace analysis {
void IrGraphToProgramPass::RunImpl(Argument *argument) {
auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
if (argument->memory_optim_sort_kind_valid()) {
pass->Set(framework::ir::kGraphToProgramSortKind,
new int(argument->memory_optim_sort_kind()));
}
std::unique_ptr<Graph> graph(argument->main_graph_ptr());
framework::ProgramDesc desc(argument->main_program());
pass->SetNotOwned("program", &desc);
auto thegraph = pass->Apply(std::move(graph));
thegraph.release();  // the argument still owns the graph.
argument->SetIrAnalyzedProgram(
new framework::proto::ProgramDesc(*desc.Proto()));
}
} // namespace analysis
} // namespace inference
} // namespace paddle

@ -14,31 +14,17 @@
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/analysis_pass.h"
#include "paddle/fluid/inference/analysis/passes/passes.h"
namespace paddle {
namespace inference {
namespace analysis {
/*
* The analysis pass to run a list of IR passes (like a function call).
* Currently, it should be the first pass of analysis phase.
*/
class IrAnalysisComposePass : public AnalysisPass {
class IrGraphToProgramPass : public AnalysisPass {
public:
void RunImpl(Argument* argument) override;
std::string repr() const override;
void RunImpl(Argument *argument) override;
private:
void ApplyIrPasses(Argument* argument);
void CollectFusionStatis(Argument* argument);
// Assign a Scope for IR passes to modify the weights.
void AssignScopeToModify(Argument* argument);
std::string repr() const override { return "ir-graph-to-param-pass"; }
};
} // namespace analysis
