|
|
|
@ -85,16 +85,13 @@ TEST(ConvReLUFusePass, basic) {
|
|
|
|
|
|
|
|
|
|
for (auto* node : graph->Nodes()) {
|
|
|
|
|
if (node->IsOp() && node->Op()->Type() == "conv2d") {
|
|
|
|
|
if (node->Op()->HasAttr("use_mkldnn")) {
|
|
|
|
|
bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
|
|
|
|
|
if (use_mkldnn) {
|
|
|
|
|
if (node->Op()->HasAttr("fuse_relu")) {
|
|
|
|
|
bool fuse_relu = boost::get<bool>(node->Op()->GetAttr("fuse_relu"));
|
|
|
|
|
if (fuse_relu) {
|
|
|
|
|
++conv_relu_count;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto* op = node->Op();
|
|
|
|
|
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
|
|
|
|
|
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
|
|
|
|
|
ASSERT_TRUE(op->HasAttr("fuse_relu"));
|
|
|
|
|
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
|
|
|
|
|
if (fuse_relu) {
|
|
|
|
|
++conv_relu_count;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|