|
|
|
@ -28,6 +28,8 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
|
|
|
|
|
// Declares the int32 command-line flag `tensorrt_engine_batch_size` (presumably
// defined in another translation unit -- confirm). In this file it is read as
// FLAGS_tensorrt_engine_batch_size: it is bounds-checked with PADDLE_ENFORCE_LE
// and pushed as the first dimension of the output tensor shape.
DECLARE_int32(tensorrt_engine_batch_size);
|
|
|
|
|
// Declares the int32 flag `tensorrt_max_batch_size`; read in this file as
// FLAGS_tensorrt_max_batch_size to initialise `max_batch`.
DECLARE_int32(tensorrt_max_batch_size);
|
|
|
|
|
// Declares the int32 flag `tensorrt_workspace_size`; read in this file as
// FLAGS_tensorrt_workspace_size to initialise `max_workspace`.
// NOTE(review): units are not visible here -- presumably bytes; confirm.
DECLARE_int32(tensorrt_workspace_size);
|
|
|
|
|
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
@ -54,8 +56,10 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t>& shape) {
|
|
|
|
|
"TensorRT' tensor input requires at least 2 dimensions");
|
|
|
|
|
PADDLE_ENFORCE_LE(shape.size(), 4UL,
|
|
|
|
|
"TensorRT' tensor input requires at most 4 dimensions");
|
|
|
|
|
PADDLE_ENFORCE_EQ(shape.size(), 4UL);
|
|
|
|
|
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
|
|
|
|
|
PADDLE_ENFORCE(shape.size() == 4UL || shape.size() == 2UL);
|
|
|
|
|
if (shape.size() == 4UL)
|
|
|
|
|
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
|
|
|
|
|
return nvinfer1::DimsCHW(shape[1], 1, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace
|
|
|
|
@ -95,7 +99,7 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto input_names = context.op().Inputs("Xs");
|
|
|
|
|
PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
|
|
|
|
|
PADDLE_ENFORCE_LE(FLAGS_tensorrt_engine_batch_size,
|
|
|
|
|
context.Attr<int>("max_batch"));
|
|
|
|
|
FLAGS_tensorrt_max_batch_size);
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> output_maps =
|
|
|
|
|
context.Attr<std::vector<std::string>>("output_name_mapping");
|
|
|
|
@ -132,7 +136,12 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
|
|
|
|
|
nvinfer1::ITensor* trt_t = engine->GetITensor(output_maps[output_index]);
|
|
|
|
|
auto dims = trt_t->getDimensions();
|
|
|
|
|
// Use the output ITensor's dims to reshape the Fluid Tensor.
|
|
|
|
|
std::vector<int> ddim(dims.d, dims.d + dims.nbDims);
|
|
|
|
|
// The ITensor doesn't contain the batch size dim.
|
|
|
|
|
std::vector<int> ddim;
|
|
|
|
|
ddim.push_back(FLAGS_tensorrt_engine_batch_size);
|
|
|
|
|
for (int i = 0; i < dims.nbDims; i++) {
|
|
|
|
|
ddim.push_back(dims.d[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* fluid_v = context.scope().FindVar(y);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(fluid_v, "no output variable called %s", y);
|
|
|
|
@ -168,8 +177,8 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Parse the serialized BlockDesc from the "subgraph" attribute and pass it on for conversion.
|
|
|
|
|
framework::proto::BlockDesc block_desc;
|
|
|
|
|
block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
|
|
|
|
|
int max_batch = context.Attr<int>("max_batch");
|
|
|
|
|
auto max_workspace = context.Attr<int>("max_workspace");
|
|
|
|
|
int max_batch = FLAGS_tensorrt_max_batch_size;
|
|
|
|
|
auto max_workspace = FLAGS_tensorrt_workspace_size;
|
|
|
|
|
auto params = context.Attr<std::vector<std::string>>("parameters");
|
|
|
|
|
std::unordered_set<std::string> parameters;
|
|
|
|
|
for (const auto& param : params) {
|
|
|
|
|