|
|
|
@ -12,12 +12,14 @@
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/helper.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_detector.h"
|
|
|
|
|
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace inference {
|
|
|
|
@ -197,10 +199,26 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> ExtractParameters(
|
|
|
|
|
const std::unordered_set<Node *> &nodes) {
|
|
|
|
|
// We can judge whether a variable is a parameter by
|
|
|
|
|
// its presistable property, but sometimes the presistable
|
|
|
|
|
// of the feed op output is true, so we have to identify it.
|
|
|
|
|
std::vector<std::string> feed_outputs;
|
|
|
|
|
for (const auto &node : nodes) {
|
|
|
|
|
if (!node->IsOp()) continue;
|
|
|
|
|
std::string op_type = node->Op()->Type();
|
|
|
|
|
if (op_type == "feed") {
|
|
|
|
|
std::vector<std::string> output_names = node->Op()->OutputArgumentNames();
|
|
|
|
|
std::copy(output_names.begin(), output_names.end(),
|
|
|
|
|
std::back_inserter(feed_outputs));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> parameters;
|
|
|
|
|
for (const auto &node : nodes) {
|
|
|
|
|
if (!node->IsVar()) continue;
|
|
|
|
|
if (node->Var()->Persistable()) {
|
|
|
|
|
if (node->Var()->Persistable() &&
|
|
|
|
|
std::find(feed_outputs.begin(), feed_outputs.end(), node->Name()) ==
|
|
|
|
|
feed_outputs.end()) {
|
|
|
|
|
parameters.push_back(node->Name());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|