|
|
|
@ -40,7 +40,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
|
|
|
|
|
op->SetOutput(output.first, {output.second});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
struct IsReachable {
|
|
|
|
|
struct TestIsReachable {
|
|
|
|
|
using func = std::function<bool(const std::string&, const std::string&)>;
|
|
|
|
|
|
|
|
|
|
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
|
|
|
|
@ -89,7 +89,9 @@ struct IsReachable {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
|
|
|
|
|
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph,
|
|
|
|
|
int expected_conv_count,
|
|
|
|
|
int expected_elementwise_add_count = 0) {
|
|
|
|
|
int conv_count = 0;
|
|
|
|
|
int elementwise_add_count = 0;
|
|
|
|
|
|
|
|
|
@ -101,8 +103,8 @@ void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
|
|
|
|
|
++elementwise_add_count;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
EXPECT_EQ(conv_count, 1);
|
|
|
|
|
EXPECT_EQ(elementwise_add_count, 0);
|
|
|
|
|
EXPECT_EQ(conv_count, expected_conv_count);
|
|
|
|
|
EXPECT_EQ(elementwise_add_count, expected_elementwise_add_count);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
|
|
|
|
@ -127,22 +129,13 @@ ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
|
|
|
|
|
|
|
|
|
|
return prog;
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
|
|
|
|
|
auto prog =
|
|
|
|
|
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "conv2d",
|
|
|
|
|
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "b"});
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
|
|
|
|
void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
|
|
|
|
|
const std::string& to, int expected_conv_num) {
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(*prog));
|
|
|
|
|
|
|
|
|
|
IsReachable is_reachable;
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
|
|
|
|
|
TestIsReachable is_reachable;
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)(from, to));
|
|
|
|
|
|
|
|
|
|
auto pass =
|
|
|
|
|
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
|
|
|
|
@ -150,82 +143,87 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
|
|
|
|
|
graph = pass->Apply(std::move(graph));
|
|
|
|
|
int current_nodes_num = graph->Nodes().size();
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)(from, to));
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
|
|
|
|
|
current_nodes_num);
|
|
|
|
|
|
|
|
|
|
AssertOpsCount(graph);
|
|
|
|
|
AssertOpsCount(graph, expected_conv_num);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass,
|
|
|
|
|
ConvolutionWithElementwiseAddReluNoBias) {
|
|
|
|
|
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
|
|
|
|
|
SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "b"});
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
|
|
|
|
|
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
|
|
|
|
|
|
|
|
|
|
IsReachable is_reachable;
|
|
|
|
|
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
|
|
|
|
|
SetOp(&prog, "conv2d",
|
|
|
|
|
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "c"});
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
|
|
|
|
|
|
|
|
|
|
auto pass =
|
|
|
|
|
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
|
|
|
|
|
int original_nodes_num = graph->Nodes().size();
|
|
|
|
|
graph = pass->Apply(std::move(graph));
|
|
|
|
|
int current_nodes_num = graph->Nodes().size();
|
|
|
|
|
RunPassAndAssert(&prog, "a", "relu", 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass,
|
|
|
|
|
ConvolutionAsYWithElementwiseAddReluNoBias) {
|
|
|
|
|
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
|
|
|
|
|
current_nodes_num);
|
|
|
|
|
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
|
|
|
|
|
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "c"});
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
|
|
|
|
|
|
|
|
|
|
AssertOpsCount(graph);
|
|
|
|
|
RunPassAndAssert(&prog, "a", "relu", 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
|
|
|
|
|
auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
|
|
|
|
|
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
|
|
|
|
|
SetOp(&prog, "conv2d",
|
|
|
|
|
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "b"});
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
|
|
|
|
|
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "c"});
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
|
|
|
|
|
|
|
|
|
|
IsReachable is_reachable;
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "d"));
|
|
|
|
|
RunPassAndAssert(&prog, "a", "relu", 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto pass =
|
|
|
|
|
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
|
|
|
|
|
int original_nodes_num = graph->Nodes().size();
|
|
|
|
|
graph = pass->Apply(std::move(graph));
|
|
|
|
|
int current_nodes_num = graph->Nodes().size();
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass,
|
|
|
|
|
ConvolutionAsXWithElementwiseAddReluNoBias) {
|
|
|
|
|
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
|
|
|
|
|
|
|
|
|
|
EXPECT_FALSE(is_reachable(graph)("a", "d"));
|
|
|
|
|
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
|
|
|
|
|
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "c"});
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
|
|
|
|
|
current_nodes_num);
|
|
|
|
|
AssertOpsCount(graph);
|
|
|
|
|
RunPassAndAssert(&prog, "a", "relu", 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
|
|
|
|
|
TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
|
|
|
|
|
auto prog =
|
|
|
|
|
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
|
|
|
|
|
BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"});
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
|
|
|
|
|
SetOp(&prog, "conv2d",
|
|
|
|
|
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
|
|
|
|
|
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "c"});
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
|
|
|
|
SetOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
|
|
|
|
|
{"Output", "e"});
|
|
|
|
|
|
|
|
|
|
IsReachable is_reachable;
|
|
|
|
|
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, {"Out", "f"});
|
|
|
|
|
SetOp(&prog, "relu", {{"X", "f"}}, {"Out", "g"});
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "f"));
|
|
|
|
|
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
|
|
|
|
|
|
|
|
|
|
TestIsReachable is_reachable;
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "g"));
|
|
|
|
|
|
|
|
|
|
auto pass =
|
|
|
|
|
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
|
|
|
|
@ -233,11 +231,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
|
|
|
|
|
graph = pass->Apply(std::move(graph));
|
|
|
|
|
int current_nodes_num = graph->Nodes().size();
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "f"));
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "g"));
|
|
|
|
|
EXPECT_EQ(original_nodes_num, current_nodes_num);
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
|
|
|
|
|
current_nodes_num);
|
|
|
|
|
AssertOpsCount(graph);
|
|
|
|
|
AssertOpsCount(graph, 2, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
|