|
|
|
@ -28,8 +28,6 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
DECLARE_int32(tensorrt_engine_batch_size);
|
|
|
|
|
DECLARE_int32(tensorrt_max_batch_size);
|
|
|
|
|
DECLARE_int32(tensorrt_workspace_size);
|
|
|
|
|
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -92,14 +90,14 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
auto engine_name = context.Attr<std::string>("engine_uniq_key");
|
|
|
|
|
int max_batch_size = context.Attr<int>("max_batch_size");
|
|
|
|
|
if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
|
|
|
|
|
Prepare(context);
|
|
|
|
|
}
|
|
|
|
|
auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
|
|
|
|
|
auto input_names = context.op().Inputs("Xs");
|
|
|
|
|
PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
|
|
|
|
|
PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size,
|
|
|
|
|
FLAGS_tensorrt_max_batch_size);
|
|
|
|
|
PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size, max_batch_size);
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> output_maps =
|
|
|
|
|
context.Attr<std::vector<std::string>>("output_name_mapping");
|
|
|
|
@ -173,8 +171,9 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Get the ProgramDesc and pass to convert.
|
|
|
|
|
framework::proto::BlockDesc block_desc;
|
|
|
|
|
block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
|
|
|
|
|
int max_batch = FLAGS_tensorrt_max_batch_size;
|
|
|
|
|
auto max_workspace = FLAGS_tensorrt_workspace_size;
|
|
|
|
|
int max_batch_size = context.Attr<int>("max_batch_size");
|
|
|
|
|
int workspace_size = context.Attr<int>("workspace_size");
|
|
|
|
|
|
|
|
|
|
auto params = context.Attr<std::vector<std::string>>("parameters");
|
|
|
|
|
std::unordered_set<std::string> parameters;
|
|
|
|
|
for (const auto& param : params) {
|
|
|
|
@ -186,7 +185,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// TODO(Superjomn) replace this with a different stream
|
|
|
|
|
auto* engine = Singleton<TRT_EngineManager>::Global().Create(
|
|
|
|
|
max_batch, max_workspace, nullptr /*engine hold its own stream*/,
|
|
|
|
|
max_batch_size, workspace_size, nullptr /*engine hold its own stream*/,
|
|
|
|
|
context.Attr<std::string>("engine_uniq_key"),
|
|
|
|
|
boost::get<platform::CUDAPlace>(context.GetPlace()).device);
|
|
|
|
|
|
|
|
|
|