|
|
|
@ -58,8 +58,6 @@ void AddTensorToBlockDesc(framework::proto::BlockDesc* block,
|
|
|
|
|
using inference::analysis::SetAttr;
|
|
|
|
|
|
|
|
|
|
TEST(TensorRTEngineOp, manual) {
|
|
|
|
|
FLAGS_tensorrt_engine_batch_size = 2;
|
|
|
|
|
FLAGS_tensorrt_max_batch_size = 2;
|
|
|
|
|
framework::ProgramDesc program;
|
|
|
|
|
auto* block_ = program.Proto()->add_blocks();
|
|
|
|
|
block_->set_idx(0);
|
|
|
|
@ -101,6 +99,8 @@ TEST(TensorRTEngineOp, manual) {
|
|
|
|
|
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
|
|
|
|
|
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
|
|
|
|
|
block_->SerializeAsString());
|
|
|
|
|
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", 2);
|
|
|
|
|
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
|
|
|
|
|
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
|
|
|
|
|
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
|
|
|
|
|
std::vector<std::string>({}));
|
|
|
|
@ -129,8 +129,6 @@ TEST(TensorRTEngineOp, manual) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
|
|
|
|
|
FLAGS_tensorrt_engine_batch_size = batch_size;
|
|
|
|
|
FLAGS_tensorrt_max_batch_size = batch_size;
|
|
|
|
|
framework::ProgramDesc program;
|
|
|
|
|
framework::Scope scope;
|
|
|
|
|
platform::CUDAPlace place;
|
|
|
|
@ -195,8 +193,8 @@ void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
|
|
|
|
|
|
|
|
|
|
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
|
|
|
|
|
block_->SerializeAsString());
|
|
|
|
|
SetAttr<int>(engine_op_desc.Proto(), "max_batch", batch_size);
|
|
|
|
|
SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 2 << 10);
|
|
|
|
|
SetAttr<int>(engine_op_desc.Proto(), "max_batch_size", batch_size);
|
|
|
|
|
SetAttr<int>(engine_op_desc.Proto(), "workspace_size", 2 << 10);
|
|
|
|
|
SetAttr<std::vector<std::string>>(
|
|
|
|
|
engine_op_desc.Proto(), "parameters",
|
|
|
|
|
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
|
|
|
|
|