|
|
@ -63,7 +63,6 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
|
|
|
|
void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
|
|
|
|
void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
|
|
|
|
Graph *graph) const {
|
|
|
|
Graph *graph) const {
|
|
|
|
auto *op_desc = node->Op();
|
|
|
|
auto *op_desc = node->Op();
|
|
|
|
static int counter{0};
|
|
|
|
|
|
|
|
auto &subgraph = *Agent(node).subgraph();
|
|
|
|
auto &subgraph = *Agent(node).subgraph();
|
|
|
|
PADDLE_ENFORCE(!subgraph.empty());
|
|
|
|
PADDLE_ENFORCE(!subgraph.empty());
|
|
|
|
|
|
|
|
|
|
|
@ -192,8 +191,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
|
|
|
|
block_desc.Proto()->SerializeAsString());
|
|
|
|
block_desc.Proto()->SerializeAsString());
|
|
|
|
SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
|
|
|
|
SetAttr(op_desc->Proto(), "max_batch_size", Get<int>("max_batch_size"));
|
|
|
|
SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
|
|
|
|
SetAttr(op_desc->Proto(), "workspace_size", Get<int>("workspace_size"));
|
|
|
|
SetAttr(op_desc->Proto(), "engine_uniq_key",
|
|
|
|
|
|
|
|
"trt-" + std::to_string(counter++));
|
|
|
|
|
|
|
|
SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
|
|
|
|
SetAttr(op_desc->Proto(), "parameters", ExtractParameters(graph->Nodes()));
|
|
|
|
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
|
|
|
|
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
|
|
|
|
}
|
|
|
|
}
|
|
|
|