|
|
@ -16,6 +16,10 @@
|
|
|
|
* This file defines TensorRTSubgraphNodeMarkPass which helps to mark the ops
|
|
|
|
* This file defines TensorRTSubgraphNodeMarkPass which helps to mark the ops
|
|
|
|
* that supported by TensorRT engine.
|
|
|
|
* that supported by TensorRT engine.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <string>
|
|
|
|
#include "paddle/fluid/inference/analysis/pass.h"
|
|
|
|
#include "paddle/fluid/inference/analysis/pass.h"
|
|
|
|
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
|
|
|
|
#include "paddle/fluid/inference/analysis/subgraph_splitter.h"
|
|
|
|
|
|
|
|
|
|
|
@ -30,7 +34,8 @@ class TensorRTSubgraphNodeMarkPass : public DataFlowGraphPass {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using teller_t = SubGraphSplitter::NodeInsideSubgraphTeller;
|
|
|
|
using teller_t = SubGraphSplitter::NodeInsideSubgraphTeller;
|
|
|
|
|
|
|
|
|
|
|
|
TensorRTSubgraphNodeMarkPass(const teller_t& teller) : teller_(teller) {}
|
|
|
|
explicit TensorRTSubgraphNodeMarkPass(const teller_t& teller)
|
|
|
|
|
|
|
|
: teller_(teller) {}
|
|
|
|
|
|
|
|
|
|
|
|
bool Initialize(Argument* argument) override { return true; }
|
|
|
|
bool Initialize(Argument* argument) override { return true; }
|
|
|
|
|
|
|
|
|
|
|
@ -38,8 +43,10 @@ class TensorRTSubgraphNodeMarkPass : public DataFlowGraphPass {
|
|
|
|
// sub-graph into TensorRT.
|
|
|
|
// sub-graph into TensorRT.
|
|
|
|
void Run(DataFlowGraph* graph) override;
|
|
|
|
void Run(DataFlowGraph* graph) override;
|
|
|
|
|
|
|
|
|
|
|
|
std::string repr() const { return "tensorrt-sub-subgraph-mark"; }
|
|
|
|
std::string repr() const override { return "tensorrt-sub-subgraph-mark"; }
|
|
|
|
std::string description() const { return "tensorrt sub-graph mark pass"; }
|
|
|
|
std::string description() const override {
|
|
|
|
|
|
|
|
return "tensorrt sub-graph mark pass";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Pass* CreateGraphvizDebugerPass() const override;
|
|
|
|
Pass* CreateGraphvizDebugerPass() const override;
|
|
|
|
bool Finalize() override;
|
|
|
|
bool Finalize() override;
|
|
|
|