[Fix bug] If the pass name is not found, IsCompatible should return false (#28475)

musl/fix_failed_unittests_in_musl
lidanqing 4 years ago committed by GitHub
parent b258caf467
commit 0fc181dbd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -158,7 +158,7 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d_transpose", 0)
.LE("conv2d_transpose", 1)
.EQ("elementwise_add", 0));
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,

@ -326,6 +326,7 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(quant_conv2d_dequant_fuse_pass,
paddle::framework::ir::QuantDequantFusePass);
REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass);
REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.AddCombination(

@ -308,7 +308,7 @@ class PassVersionCheckerRegistrar {
bool IsPassCompatible(const std::string& fuse_pass_name) const {
auto iter = pass_version_checkers_map_.find(fuse_pass_name);
if (iter == pass_version_checkers_map_.end()) {
return true;
return false;
}
return iter->second.IsPassCompatible();
}

@ -57,6 +57,10 @@ TEST(test_operator_version, test_operator_version) {
TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
const std::string fake_op_name{"op_name__"};
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_registered_capability_pass"));
REGISTER_PASS_CAPABILITY(no_bind_pass);
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_bind_pass"));

Loading…
Cancel
Save