|
|
|
@ -68,6 +68,19 @@ std::unique_ptr<framework::ir::Graph> analysis::TensorRtSubgraphPass::ApplyImpl(
|
|
|
|
|
return graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string GenerateEngineKey(const std::set<std::string> &engine_inputs,
|
|
|
|
|
const std::set<std::string> &engine_outputs) {
|
|
|
|
|
std::string engine_hash_key = "";
|
|
|
|
|
for (auto name : engine_inputs) {
|
|
|
|
|
engine_hash_key += name;
|
|
|
|
|
}
|
|
|
|
|
for (auto name : engine_outputs) {
|
|
|
|
|
engine_hash_key += name;
|
|
|
|
|
}
|
|
|
|
|
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
|
|
|
|
|
return engine_key;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
|
|
|
|
|
Graph *graph) const {
|
|
|
|
|
auto *op_desc = node->Op();
|
|
|
|
@ -97,7 +110,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
|
|
|
|
|
*op->Proto() = *node->Op()->Proto();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// collect inputs
|
|
|
|
|
// Then, we will use the input_names_with_id and output_names_with_id to
|
|
|
|
|
// generate the eigine key.
|
|
|
|
|
// So, We use set instead of unordered_set here to ensure that the engine key
|
|
|
|
|
// is unique.
|
|
|
|
|
std::set<std::string> input_names;
|
|
|
|
|
std::set<std::string> input_names_with_id;
|
|
|
|
|
for (auto *x : node->inputs) {
|
|
|
|
@ -217,30 +233,13 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
|
|
|
|
|
SetAttr(op_desc->Proto(), "output_name_mapping", output_mapping);
|
|
|
|
|
|
|
|
|
|
auto enable_int8 = Get<bool>("enable_int8");
|
|
|
|
|
SetAttr(op_desc->Proto(), "calibration_data", std::string(""));
|
|
|
|
|
auto engine_key =
|
|
|
|
|
GenerateEngineKey(input_names_with_id, output_names_with_id);
|
|
|
|
|
|
|
|
|
|
// we use the subgraph's inputs and outputs to generate the engine key.
|
|
|
|
|
std::string engine_hash_key = "";
|
|
|
|
|
for (auto name : input_names_with_id) {
|
|
|
|
|
engine_hash_key += name;
|
|
|
|
|
}
|
|
|
|
|
for (auto name : output_names_with_id) {
|
|
|
|
|
engine_hash_key += name;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto engine_key = std::to_string(std::hash<std::string>()(engine_hash_key));
|
|
|
|
|
std::string calibration_data = GetTrtCalibTableData(
|
|
|
|
|
Get<std::string>("model_opt_cache_dir"), engine_key, enable_int8);
|
|
|
|
|
SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
|
|
|
|
|
|
|
|
|
|
auto trt_calib_file =
|
|
|
|
|
GetTrtCalibPath(Get<std::string>("model_dir"), engine_key);
|
|
|
|
|
VLOG(3) << "engine key: " << engine_key;
|
|
|
|
|
if (enable_int8 && FileExists(trt_calib_file)) {
|
|
|
|
|
VLOG(3) << "Calibration table file: " << trt_calib_file << "is found here";
|
|
|
|
|
std::ifstream infile(trt_calib_file, std::ios::in);
|
|
|
|
|
std::stringstream buffer;
|
|
|
|
|
buffer << infile.rdbuf();
|
|
|
|
|
std::string calibration_data(buffer.str());
|
|
|
|
|
SetAttr(op_desc->Proto(), "calibration_data", calibration_data);
|
|
|
|
|
}
|
|
|
|
|
SetAttr(op_desc->Proto(), "enable_int8", enable_int8);
|
|
|
|
|
SetAttr(op_desc->Proto(), "engine_key", engine_key);
|
|
|
|
|
}
|
|
|
|
|