|
|
|
@ -142,10 +142,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
|
|
|
|
|
LOG_FIRST_N(INFO, 1) << "The TRT engine: " << engine_key_
|
|
|
|
|
<< " is running calibration trt int8... ";
|
|
|
|
|
int runtime_batch = 1;
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(dev_place);
|
|
|
|
|
auto stream =
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
|
|
|
|
|
if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_key_)) {
|
|
|
|
|
TRTCalibratorEngine *calib_res =
|
|
|
|
|
Singleton<TRTCalibratorEngineManager>::Global().Create(engine_key_);
|
|
|
|
@ -162,10 +158,10 @@ class TensorRTEngineOp : public framework::OperatorBase {
|
|
|
|
|
calib_buffers, runtime_batch, engine_key_, dev_place));
|
|
|
|
|
calib_res->thr_.reset(new std::thread([&]() {
|
|
|
|
|
calib_res->engine_.reset(
|
|
|
|
|
new TensorRTEngine(max_batch_size_, workspace_size_, stream,
|
|
|
|
|
enable_int8_, calib_res->calib_.get()));
|
|
|
|
|
new TensorRTEngine(max_batch_size_, workspace_size_, enable_int8_,
|
|
|
|
|
calib_res->calib_.get()));
|
|
|
|
|
VLOG(3) << "start the calib trt engine thread";
|
|
|
|
|
Prepare(scope, dev_place, calib_res->engine_.get());
|
|
|
|
|
Prepare(scope, calib_res->engine_.get());
|
|
|
|
|
}));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -253,22 +249,17 @@ class TensorRTEngineOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_LE(runtime_batch, max_batch_size_);
|
|
|
|
|
// Execute the engine.
|
|
|
|
|
engine->Execute(runtime_batch, buffers);
|
|
|
|
|
engine->Execute(runtime_batch, &buffers, stream);
|
|
|
|
|
cudaStreamSynchronize(stream);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TensorRTEngine *GetEngine(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &dev_place) const {
|
|
|
|
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
|
|
|
|
auto &dev_ctx = *pool.Get(dev_place);
|
|
|
|
|
auto stream =
|
|
|
|
|
reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx).stream();
|
|
|
|
|
if (trt_engine_.get() == nullptr) {
|
|
|
|
|
trt_engine_.reset(new TensorRTEngine(max_batch_size_, workspace_size_,
|
|
|
|
|
stream, enable_int8_,
|
|
|
|
|
calibrator_.get()));
|
|
|
|
|
enable_int8_, calibrator_.get()));
|
|
|
|
|
if (true) {
|
|
|
|
|
Prepare(scope, dev_place, trt_engine_.get());
|
|
|
|
|
Prepare(scope, trt_engine_.get());
|
|
|
|
|
} else {
|
|
|
|
|
// create static engine
|
|
|
|
|
}
|
|
|
|
@ -276,20 +267,19 @@ class TensorRTEngineOp : public framework::OperatorBase {
|
|
|
|
|
return trt_engine_.get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Prepare(const framework::Scope &scope, const platform::Place &dev_place,
|
|
|
|
|
TensorRTEngine *engine) const {
|
|
|
|
|
void Prepare(const framework::Scope &scope, TensorRTEngine *engine) const {
|
|
|
|
|
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
|
|
|
|
|
"kernel etc). This process may cost a lot of time.";
|
|
|
|
|
framework::proto::BlockDesc block_desc;
|
|
|
|
|
block_desc.ParseFromString(Attr<std::string>("subgraph"));
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> output_maps =
|
|
|
|
|
Attr<std::vector<std::string>>("output_name_mapping");
|
|
|
|
|
framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
|
|
|
|
|
|
|
|
|
|
engine->InitNetwork();
|
|
|
|
|
|
|
|
|
|
framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
|
|
|
|
|
VLOG(4) << "parsed var size " << block.AllVars().size();
|
|
|
|
|
std::vector<std::string> output_maps =
|
|
|
|
|
Attr<std::vector<std::string>>("output_name_mapping");
|
|
|
|
|
|
|
|
|
|
// Add inputs
|
|
|
|
|
VLOG(4) << "declare inputs";
|
|
|
|
|
for (auto &input : Inputs("Xs")) {
|
|
|
|
@ -306,12 +296,12 @@ class TensorRTEngineOp : public framework::OperatorBase {
|
|
|
|
|
PADDLE_ENFORCE(var, "no variable called %s", input);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
|
|
|
|
|
"TensorRT engine only takes LoDTensor as input");
|
|
|
|
|
|
|
|
|
|
engine->DeclareInput(
|
|
|
|
|
input, FluidDataType2TRT(
|
|
|
|
|
var->Proto()->type().lod_tensor().tensor().data_type()),
|
|
|
|
|
Vec2TRT_Dims(t_shape));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
inference::Singleton<inference::tensorrt::OpConverter>::Global()
|
|
|
|
|
.ConvertBlock(block_desc, param_names_, scope, engine);
|
|
|
|
|
|
|
|
|
|