|
|
|
@ -17,8 +17,8 @@
|
|
|
|
|
#include <set>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/subgraph_detector.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/helper.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
|
|
|
|
|
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
|
|
|
|
|
#include "paddle/fluid/inference/tensorrt/engine.h"
|
|
|
|
@ -40,9 +40,9 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
|
|
|
|
|
return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op());
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
SubGraphFuser fuser(graph, teller,
|
|
|
|
|
Get<int>("min_subgraph_size") /*min subgraph size*/,
|
|
|
|
|
"tensorrt_engine");
|
|
|
|
|
framework::ir::SubGraphFuser fuser(
|
|
|
|
|
graph, teller, Get<int>("min_subgraph_size") /*min subgraph size*/,
|
|
|
|
|
"tensorrt_engine");
|
|
|
|
|
fuser();
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> graph_param_names =
|
|
|
|
@ -52,18 +52,19 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
|
|
|
|
|
std::vector<std::string> repetitive_params;
|
|
|
|
|
|
|
|
|
|
for (auto *node : graph->Nodes()) {
|
|
|
|
|
if (node->IsOp() && !Agent(node).subgraph()->empty()) {
|
|
|
|
|
if (node->IsOp() && !framework::ir::Agent(node).subgraph()->empty()) {
|
|
|
|
|
CreateTensorRTOp(node, graph, graph_param_names, &repetitive_params);
|
|
|
|
|
|
|
|
|
|
std::unordered_set<const Node *> nodes2remove(
|
|
|
|
|
Agent(node).subgraph()->begin(), Agent(node).subgraph()->end());
|
|
|
|
|
framework::ir::Agent(node).subgraph()->begin(),
|
|
|
|
|
framework::ir::Agent(node).subgraph()->end());
|
|
|
|
|
framework::ir::GraphSafeRemoveNodes(graph, nodes2remove);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_set<const Node *> nodes2remove;
|
|
|
|
|
for (auto *node : graph->Nodes()) {
|
|
|
|
|
if (node->IsOp() && Agent(node).deleted()) {
|
|
|
|
|
if (node->IsOp() && framework::ir::Agent(node).deleted()) {
|
|
|
|
|
nodes2remove.insert(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -88,11 +89,11 @@ std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
framework::ir::Node *node, Graph *graph,
|
|
|
|
|
framework::ir::Node *node, framework::ir::Graph *graph,
|
|
|
|
|
const std::vector<std::string> &graph_params,
|
|
|
|
|
std::vector<std::string> *repetitive_params) const {
|
|
|
|
|
auto *op_desc = node->Op();
|
|
|
|
|
auto &subgraph = *Agent(node).subgraph();
|
|
|
|
|
auto &subgraph = *framework::ir::Agent(node).subgraph();
|
|
|
|
|
PADDLE_ENFORCE(!subgraph.empty());
|
|
|
|
|
|
|
|
|
|
framework::ProgramDesc *program_desc =
|
|
|
|
@ -161,7 +162,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
if (precision_mode == AnalysisConfig::Precision::kHalf) enable_fp16 = true;
|
|
|
|
|
auto enable_int8 = Get<bool>("enable_int8");
|
|
|
|
|
auto use_calib_mode = Get<bool>("use_calib_mode");
|
|
|
|
|
auto &subgraph_nodes = *Agent(node).subgraph();
|
|
|
|
|
auto &subgraph_nodes = *framework::ir::Agent(node).subgraph();
|
|
|
|
|
|
|
|
|
|
// The following procedure is used to rename all the intermediate
|
|
|
|
|
// variables and the output variables of the subgraph.
|
|
|
|
|