|
|
|
@@ -199,8 +199,12 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
SetAttr(op_desc->Proto(), "parameters", params);
|
|
|
|
|
|
|
|
|
|
auto use_static_engine = Get<bool>("use_static_engine");
|
|
|
|
|
// TODO(NHZlX)
|
|
|
|
|
// There are models with the same structure but different parameters,
|
|
|
|
|
// when running in the 'use_serialize' mode, there is a bug.
|
|
|
|
|
auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id,
|
|
|
|
|
std::to_string(0));
|
|
|
|
|
auto predictor_id = Get<int>("predictor_id");
|
|
|
|
|
|
|
|
|
|
// Get "" when there is no cached calibration table data.
|
|
|
|
|
bool load_from_memory = Get<bool>("model_from_memory");
|
|
|
|
@@ -214,6 +218,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
|
|
|
|
|
SetAttr(op_desc->Proto(), "use_calib_mode", use_calib_mode);
|
|
|
|
|
SetAttr(op_desc->Proto(), "engine_key", engine_key);
|
|
|
|
|
SetAttr(op_desc->Proto(), "predictor_id", predictor_id);
|
|
|
|
|
std::string trt_engine_serialized_data = "";
|
|
|
|
|
SetAttr(op_desc->Proto(), "engine_serialized_data",
|
|
|
|
|
trt_engine_serialized_data);
|
|
|
|
@@ -233,15 +238,20 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
|
|
|
|
|
std::copy(params.begin(), params.end(),
|
|
|
|
|
std::back_inserter(*repetitive_params));
|
|
|
|
|
bool need_serialize = (use_static_engine && !load_from_memory);
|
|
|
|
|
|
|
|
|
|
tensorrt::TensorRTEngine *trt_engine =
|
|
|
|
|
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
|
|
|
|
|
.Create(engine_key + std::to_string(predictor_id),
|
|
|
|
|
Get<int>("max_batch_size"), Get<int>("workspace_size"),
|
|
|
|
|
enable_int8, calibrator.get(), Get<int>("gpu_device_id"));
|
|
|
|
|
|
|
|
|
|
bool need_serialize = (use_static_engine && !load_from_memory);
|
|
|
|
|
if (need_serialize) {
|
|
|
|
|
trt_engine_serialized_data = GetTrtEngineSerializedData(
|
|
|
|
|
Get<std::string>("model_opt_cache_dir"), engine_key);
|
|
|
|
|
// we can load the engine info serialized before from the disk.
|
|
|
|
|
if (!trt_engine_serialized_data.empty()) {
|
|
|
|
|
SetAttr(op_desc->Proto(), "engine_serialized_data",
|
|
|
|
|
trt_engine_serialized_data);
|
|
|
|
|
trt_engine->Deserialize(trt_engine_serialized_data);
|
|
|
|
|
LOG(INFO) << "Load TRT Optimized Info from "
|
|
|
|
|
<< GetTrtEngineSerializedPath(
|
|
|
|
|
Get<std::string>("model_opt_cache_dir"), engine_key);
|
|
|
|
@@ -254,10 +264,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
// 2. already load serialized trt engine info.
|
|
|
|
|
LOG(INFO) << "Prepare TRT engine (Optimize model structure, Select OP "
|
|
|
|
|
"kernel etc). This process may cost a lot of time.";
|
|
|
|
|
std::unique_ptr<tensorrt::TensorRTEngine> trt_engine(
|
|
|
|
|
new tensorrt::TensorRTEngine(
|
|
|
|
|
Get<int>("max_batch_size"), Get<int>("workspace_size"), enable_int8,
|
|
|
|
|
calibrator.get(), Get<int>("gpu_device_id")));
|
|
|
|
|
|
|
|
|
|
auto *scope = param_scope();
|
|
|
|
|
framework::BlockDesc block_desc_temp(nullptr, block_desc.Proto());
|
|
|
|
|
std::unordered_set<std::string> param_set(params.begin(), params.end());
|
|
|
|
@@ -265,20 +272,18 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
.ConvertBlockToTRTEngine(
|
|
|
|
|
&block_desc_temp, *scope,
|
|
|
|
|
std::vector<std::string>(input_names.begin(), input_names.end()),
|
|
|
|
|
param_set, output_mapping, trt_engine.get());
|
|
|
|
|
param_set, output_mapping, trt_engine);
|
|
|
|
|
|
|
|
|
|
if (need_serialize) {
|
|
|
|
|
nvinfer1::IHostMemory *serialized_engine_data = trt_engine->Serialize();
|
|
|
|
|
trt_engine_serialized_data =
|
|
|
|
|
std::string((const char *)serialized_engine_data->data(),
|
|
|
|
|
serialized_engine_data->size());
|
|
|
|
|
|
|
|
|
|
if (need_serialize) {
|
|
|
|
|
SaveTrtEngineSerializedDataToFile(
|
|
|
|
|
GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
|
|
|
|
|
engine_key),
|
|
|
|
|
trt_engine_serialized_data);
|
|
|
|
|
}
|
|
|
|
|
SetAttr(op_desc->Proto(), "engine_serialized_data",
|
|
|
|
|
trt_engine_serialized_data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace analysis
|
|
|
|
|