|
|
|
@ -23,7 +23,7 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace inference {
|
|
|
|
|
|
|
|
|
|
DEFINE_int32(tensorrt_max_batchsize, 300, "TensorRT maximum batch size");
|
|
|
|
|
DEFINE_int32(tensorrt_max_batchsize, 3, "TensorRT maximum batch size");
|
|
|
|
|
DEFINE_int32(tensorrt_workspace_size, 2048, "TensorRT workspace size");
|
|
|
|
|
|
|
|
|
|
namespace analysis {
|
|
|
|
@ -87,27 +87,90 @@ void DataFlowGraphToFluidPass::AddFluidOp(Node *node) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
|
|
|
|
|
const framework::proto::BlockDesc &block) {
|
|
|
|
|
framework::proto::BlockDesc &block) {
|
|
|
|
|
static int counter{0};
|
|
|
|
|
PADDLE_ENFORCE(node->IsFunctionBlock());
|
|
|
|
|
framework::OpDesc desc;
|
|
|
|
|
auto *func = static_cast<FunctionBlock *>(node);
|
|
|
|
|
|
|
|
|
|
// collect inputs
|
|
|
|
|
std::vector<std::string> io;
|
|
|
|
|
std::unordered_set<std::string> input_names;
|
|
|
|
|
for (auto *x : func->inlinks) {
|
|
|
|
|
io.push_back(x->name());
|
|
|
|
|
input_names.insert(x->name());
|
|
|
|
|
}
|
|
|
|
|
desc.SetInput("Xs", io);
|
|
|
|
|
desc.SetInput(
|
|
|
|
|
"Xs", std::vector<std::string>(input_names.begin(), input_names.end()));
|
|
|
|
|
|
|
|
|
|
// collect outputs
|
|
|
|
|
io.clear();
|
|
|
|
|
std::unordered_set<std::string> output_names;
|
|
|
|
|
for (auto *x : func->outlinks) {
|
|
|
|
|
io.push_back(x->name());
|
|
|
|
|
output_names.insert(x->name());
|
|
|
|
|
}
|
|
|
|
|
desc.SetOutput("Ys", io);
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> output_temp(output_names.begin(),
|
|
|
|
|
output_names.end());
|
|
|
|
|
desc.SetOutput("Ys", output_temp);
|
|
|
|
|
desc.SetType("tensorrt_engine");
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::string> output_name_map;
|
|
|
|
|
auto subgraph_nodes = func->subgraph;
|
|
|
|
|
|
|
|
|
|
for (int index = 0; index < block.ops_size(); index++) {
|
|
|
|
|
framework::proto::OpDesc *op = block.mutable_ops(index);
|
|
|
|
|
// auto &op = block.mutable_ops(index);
|
|
|
|
|
auto correspond_node = subgraph_nodes[index];
|
|
|
|
|
PADDLE_ENFORCE_EQ(correspond_node->name(), op->type());
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, size_t> var2id;
|
|
|
|
|
for (auto *in_var : correspond_node->inlinks) {
|
|
|
|
|
var2id[in_var->name()] = in_var->id();
|
|
|
|
|
}
|
|
|
|
|
// TODO(zhaolong): add comments
|
|
|
|
|
for (int i = 0; i < op->inputs_size(); i++) {
|
|
|
|
|
framework::proto::OpDesc_Var *in_var = op->mutable_inputs(i);
|
|
|
|
|
// auto &in_var = op->mutable_inputs(i);
|
|
|
|
|
std::vector<std::string> replaced_names;
|
|
|
|
|
for (int k = 0; k < in_var->arguments_size(); k++) {
|
|
|
|
|
std::string arg_value = in_var->arguments(k);
|
|
|
|
|
if (input_names.count(arg_value)) {
|
|
|
|
|
replaced_names.push_back(arg_value);
|
|
|
|
|
} else {
|
|
|
|
|
replaced_names.push_back(arg_value +
|
|
|
|
|
std::to_string(var2id[arg_value]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
in_var->clear_arguments();
|
|
|
|
|
for (size_t k = 0; k < replaced_names.size(); k++) {
|
|
|
|
|
in_var->add_arguments(replaced_names[k]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
var2id.clear();
|
|
|
|
|
for (auto out_var : correspond_node->outlinks) {
|
|
|
|
|
var2id[out_var->name()] = out_var->id();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < op->outputs_size(); i++) {
|
|
|
|
|
framework::proto::OpDesc_Var *out_var = op->mutable_outputs(i);
|
|
|
|
|
std::vector<std::string> replaced_names;
|
|
|
|
|
for (int k = 0; k < out_var->arguments_size(); k++) {
|
|
|
|
|
std::string arg_value = out_var->arguments(k);
|
|
|
|
|
if (output_names.count(arg_value)) {
|
|
|
|
|
output_name_map[arg_value] =
|
|
|
|
|
arg_value + std::to_string(var2id[arg_value]);
|
|
|
|
|
}
|
|
|
|
|
replaced_names.push_back(arg_value + std::to_string(var2id[arg_value]));
|
|
|
|
|
}
|
|
|
|
|
out_var->clear_arguments();
|
|
|
|
|
for (size_t k = 0; k < replaced_names.size(); k++) {
|
|
|
|
|
out_var->add_arguments(replaced_names[k]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::string> output_mapping;
|
|
|
|
|
for (auto name : output_names) {
|
|
|
|
|
PADDLE_ENFORCE(output_name_map.count(name) != 0);
|
|
|
|
|
output_mapping.push_back(output_name_map[name]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(!block.vars().empty(), "the block has no var-desc");
|
|
|
|
|
// Set attrs
|
|
|
|
|
SetAttr(desc.Proto(), "subgraph", block.SerializeAsString());
|
|
|
|
@ -115,6 +178,7 @@ void CreateTrtEngineOp(Node *node, const DataFlowGraph &graph,
|
|
|
|
|
SetAttr(desc.Proto(), "max_batch", FLAGS_tensorrt_max_batchsize);
|
|
|
|
|
SetAttr(desc.Proto(), "max_workspace", FLAGS_tensorrt_workspace_size);
|
|
|
|
|
SetAttr(desc.Proto(), "parameters", ExtractParameters(graph.nodes.nodes()));
|
|
|
|
|
SetAttr(desc.Proto(), "output_name_mapping", output_mapping);
|
|
|
|
|
node->SetPbMsg(desc.Proto()->SerializeAsString());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -146,11 +210,13 @@ void DataFlowGraphToFluidPass::AddEngineOp(Node *node) {
|
|
|
|
|
LOG(INFO) << "transformed variable size: "
|
|
|
|
|
<< block_desc.Proto()->vars().size();
|
|
|
|
|
// copy ops.
|
|
|
|
|
|
|
|
|
|
for (auto *node : block_node->subgraph) {
|
|
|
|
|
auto *op = block_desc.AppendOp();
|
|
|
|
|
PADDLE_ENFORCE(!node->pb_msg().empty());
|
|
|
|
|
op->Proto()->ParseFromString(node->pb_msg());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*block_desc.Proto()->mutable_vars() =
|
|
|
|
|
argument_->origin_program_desc->blocks(0).vars();
|
|
|
|
|
PADDLE_ENFORCE(!block_desc.Proto()->vars().empty());
|
|
|
|
|