@@ -22,6 +22,7 @@ limitations under the License. */
# include <vector>
# include "paddle/fluid/framework/tensor.h"
# include "paddle/fluid/framework/tensor_util.h"
# include "paddle/fluid/inference/api/paddle_analysis_config.h"
# include "paddle/fluid/inference/engine.h"
# include "paddle/fluid/inference/tensorrt/helper.h"
# include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
@@ -61,12 +62,14 @@ class TensorRTEngine {
nvinfer1 : : Weights w_ ;
} ;
TensorRTEngine ( int max_batch , int max_workspace , bool enable_int8 = false ,
TRTInt8Calibrator * calibrator = nullptr , int device_id = 0 ,
nvinfer1 : : ILogger & logger = NaiveLogger : : Global ( ) )
TensorRTEngine (
int max_batch , int max_workspace ,
AnalysisConfig : : Precision precision = AnalysisConfig : : Precision : : kFloat32 ,
TRTInt8Calibrator * calibrator = nullptr , int device_id = 0 ,
nvinfer1 : : ILogger & logger = NaiveLogger : : Global ( ) )
: max_batch_ ( max_batch ) ,
max_workspace_ ( max_workspace ) ,
enable_int8_( enable_int8 ) ,
precision_( precision ) ,
calibrator_ ( calibrator ) ,
device_id_ ( device_id ) ,
logger_ ( logger ) { }
@@ -168,7 +171,7 @@ class TensorRTEngine {
// the max memory size the engine uses
int max_workspace_ ;
bool enable_int8_ ;
AnalysisConfig : : Precision precision_ ;
TRTInt8Calibrator * calibrator_ ;
// batch size of the current data, will be updated each Executation.
int batch_size_ { - 1 } ;
@@ -231,12 +234,12 @@ class TRTEngineManager {
return engines_ . at ( name ) . get ( ) ;
}
TensorRTEngine * Create ( std : : string name , int max_batch , int max_workspace ,
bool enable_int8 = false ,
TRTInt8Calibrator * calibrator = nullptr ,
int device_id = 0 ,
nvinfer1 : : ILogger & logger = NaiveLogger : : Global ( ) ) {
auto * p = new TensorRTEngine ( max_batch , max_workspace , enable_int8 ,
TensorRTEngine * Create (
std : : string name , int max_batch , int max_workspace ,
AnalysisConfig : : Precision precision = AnalysisConfig : : Precision : : kFloat32 ,
TRTInt8Calibrator * calibrator = nullptr , int device_id = 0 ,
nvinfer1 : : ILogger & logger = NaiveLogger : : Global ( ) ) {
auto * p = new TensorRTEngine ( max_batch , max_workspace , precision ,
calibrator , device_id , logger ) ;
engines_ [ name ] . reset ( p ) ;
return p ;