|
|
|
@ -16,8 +16,10 @@
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/executor.h"
|
|
|
|
@ -220,11 +222,11 @@ class TensorRTEngineOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
TensorRTEngine *GetEngine(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &dev_place) const {
|
|
|
|
|
if (trt_engine_.get() == nullptr) {
|
|
|
|
|
if (!trt_engine_) {
|
|
|
|
|
trt_engine_.reset(new inference::tensorrt::TensorRTEngine(
|
|
|
|
|
max_batch_size_, workspace_size_, enable_int8_, calibrator_.get(),
|
|
|
|
|
boost::get<platform::CUDAPlace>(dev_place).device));
|
|
|
|
|
if (engine_serialized_data_.size() > 0) {
|
|
|
|
|
if (!engine_serialized_data_.empty()) {
|
|
|
|
|
trt_engine_->Deserialize(engine_serialized_data_);
|
|
|
|
|
} else {
|
|
|
|
|
PrepareTRTEngine(scope, trt_engine_.get());
|
|
|
|
|