fix trt delete_pass bug. (#28763)

musl/fix_failed_unittests_in_musl
Wilber 5 years ago committed by GitHub
parent 1dad8ceaab
commit a22ea652cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -175,12 +175,20 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#undef CP_MEMBER
// Update();
// Update() will reset all the passes, when some tensorRT pass is deleted in
// other.pass_builder(), it will set again, so just copy the passes.
pass_builder_->ClearPasses();
for (const std::string &pass : other.pass_builder()->AllPasses()) {
pass_builder_->AppendPass(pass);
Update();
if (use_tensorrt_) {
// Update() will reset all the passes, when some tensorRT pass is deleted in
// other.pass_builder(), it will set again, so we just remove the
// deleted_pass.
auto all_passes = kTRTSubgraphPasses;
auto other_passes = other.pass_builder()->AllPasses();
std::vector<std::string> deleted_passes;
std::set_difference(all_passes.begin(), all_passes.end(),
other_passes.begin(), other_passes.end(),
std::inserter(deleted_passes, deleted_passes.begin()));
for (auto ps : deleted_passes) {
pass_builder_->DeletePass(ps);
}
}
}

@ -77,4 +77,18 @@ TEST(paddle_inference_api, UpdateDllFlag) {
LOG(INFO) << e.what();
}
}
TEST(paddle_inference_api, AnalysisConfigCopyCtor) {
AnalysisConfig cfg1;
cfg1.EnableUseGpu(10);
cfg1.EnableTensorRtEngine();
std::string delete_pass("skip_layernorm_fuse_pass");
cfg1.pass_builder()->DeletePass(delete_pass);
AnalysisConfig cfg2(cfg1);
auto passes = cfg2.pass_builder()->AllPasses();
for (auto ps : passes) {
CHECK_NE(ps, delete_pass);
}
}
} // namespace paddle

Loading…
Cancel
Save