|
|
|
@ -83,16 +83,29 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
|
|
|
|
|
|
|
|
|
|
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
|
|
|
|
|
const std::set<std::string> &engine_outputs,
|
|
|
|
|
const std::string &predictor_id) {
|
|
|
|
|
const std::string &predictor_id,
|
|
|
|
|
const std::string &max_batch_size,
|
|
|
|
|
const std::string &precision,
|
|
|
|
|
const std::string &use_calib_mode) {
|
|
|
|
|
std::string engine_hash_key = "";
|
|
|
|
|
for (auto name : engine_inputs) {
|
|
|
|
|
engine_hash_key += name;
|
|
|
|
|
engine_hash_key += "#";
|
|
|
|
|
}
|
|
|
|
|
for (auto name : engine_outputs) {
|
|
|
|
|
engine_hash_key += name;
|
|
|
|
|
engine_hash_key += "#";
|
|
|
|
|
}
|
|
|
|
|
engine_hash_key += predictor_id;
|
|
|
|
|
engine_hash_key += "#";
|
|
|
|
|
engine_hash_key += max_batch_size;
|
|
|
|
|
engine_hash_key += "#";
|
|
|
|
|
engine_hash_key += precision;
|
|
|
|
|
engine_hash_key += "#";
|
|
|
|
|
engine_hash_key += use_calib_mode;
|
|
|
|
|
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
|
|
|
|
|
VLOG(2) << "TRT engine hash key: " << engine_hash_key;
|
|
|
|
|
VLOG(2) << "TRT engine key: " << engine_key;
|
|
|
|
|
return engine_key;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -245,8 +258,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
// TODO(NHZlX)
|
|
|
|
|
// There are models with the same structure but the different parameters,
|
|
|
|
|
// when running in the 'use_serialize' mode, there is a bug.
|
|
|
|
|
auto engine_key = GenerateEngineKey(input_names_with_id, output_names_with_id,
|
|
|
|
|
std::to_string(0));
|
|
|
|
|
auto engine_key = GenerateEngineKey(
|
|
|
|
|
input_names_with_id, output_names_with_id, std::to_string(0),
|
|
|
|
|
std::to_string(Get<int>("max_batch_size")),
|
|
|
|
|
std::to_string(static_cast<int>(precision_mode)),
|
|
|
|
|
std::to_string(static_cast<int>(use_calib_mode)));
|
|
|
|
|
auto predictor_id = Get<int>("predictor_id");
|
|
|
|
|
|
|
|
|
|
// Get "" when there is no cached calibration table data.
|
|
|
|
@ -359,6 +375,9 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
|
|
|
|
|
GetTrtEngineSerializedPath(Get<std::string>("model_opt_cache_dir"),
|
|
|
|
|
engine_key),
|
|
|
|
|
trt_engine_serialized_data);
|
|
|
|
|
LOG(INFO) << "Save TRT Optimized Info to "
|
|
|
|
|
<< GetTrtEngineSerializedPath(
|
|
|
|
|
Get<std::string>("model_opt_cache_dir"), engine_key);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|