|
|
@ -8,6 +8,9 @@ namespace paddle {
|
|
|
|
namespace framework {
|
|
|
|
namespace framework {
|
|
|
|
namespace ir {
|
|
|
|
namespace ir {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
constexpr int nodes_removed = 3;
|
|
|
|
|
|
|
|
constexpr int nodes_added = 1;
|
|
|
|
|
|
|
|
|
|
|
|
void SetOp(ProgramDesc* prog, const std::string& type,
|
|
|
|
void SetOp(ProgramDesc* prog, const std::string& type,
|
|
|
|
const std::vector<std::string>& inputs,
|
|
|
|
const std::vector<std::string>& inputs,
|
|
|
|
const std::vector<std::string>& outputs) {
|
|
|
|
const std::vector<std::string>& outputs) {
|
|
|
@ -93,7 +96,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
|
|
|
|
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
|
|
|
|
SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
|
|
|
|
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
|
|
|
|
SetOp(&prog, "relu", {"d"}, {"e"});
|
|
|
|
SetOp(&prog, "relu", {"d"}, {"e"});
|
|
|
|
|
|
|
|
|
|
|
|
return prog;
|
|
|
|
return prog;
|
|
|
@ -113,7 +116,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
|
|
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
|
|
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
|
|
|
|
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
|
|
|
|
// Assert conv_relu op in newly generated graph
|
|
|
|
// Assert conv_relu op in newly generated graph
|
|
|
|
int conv_count = 0;
|
|
|
|
int conv_count = 0;
|
|
|
|
int elementwise_add_count = 0;
|
|
|
|
int elementwise_add_count = 0;
|
|
|
@ -143,7 +146,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
|
|
|
|
SetOp(&prog, "conv2d", {"a", "weights"}, {"b"});
|
|
|
|
SetOp(&prog, "elementwise_add", {"b", "c"}, {"d"});
|
|
|
|
SetOp(&prog, "elementwise_add", {"c", "b"}, {"d"});
|
|
|
|
|
|
|
|
|
|
|
|
return prog;
|
|
|
|
return prog;
|
|
|
|
};
|
|
|
|
};
|
|
|
@ -161,7 +164,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
|
|
|
|
|
|
|
|
|
|
|
|
EXPECT_FALSE(is_reachable(graph)("a", "d"));
|
|
|
|
EXPECT_FALSE(is_reachable(graph)("a", "d"));
|
|
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
|
|
|
|
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
|
|
|
|
// Assert conv_relu op in newly generated graph
|
|
|
|
// Assert conv_relu op in newly generated graph
|
|
|
|
int conv_count = 0;
|
|
|
|
int conv_count = 0;
|
|
|
|
int elementwise_add_count = 0;
|
|
|
|
int elementwise_add_count = 0;
|
|
|
@ -192,7 +195,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
|
|
|
|
|
|
|
|
|
|
|
|
SetOp(&prog, "sigmoid", {"a"}, {"b"});
|
|
|
|
SetOp(&prog, "sigmoid", {"a"}, {"b"});
|
|
|
|
SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
|
|
|
|
SetOp(&prog, "conv2d", {"b", "weights"}, {"c"});
|
|
|
|
SetOp(&prog, "elementwise_add", {"c", "d"}, {"e"});
|
|
|
|
SetOp(&prog, "elementwise_add", {"d", "c"}, {"e"});
|
|
|
|
SetOp(&prog, "relu", {"e"}, {"f"});
|
|
|
|
SetOp(&prog, "relu", {"e"}, {"f"});
|
|
|
|
|
|
|
|
|
|
|
|
return prog;
|
|
|
|
return prog;
|
|
|
@ -212,7 +215,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
|
|
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "f"));
|
|
|
|
EXPECT_TRUE(is_reachable(graph)("a", "f"));
|
|
|
|
|
|
|
|
|
|
|
|
EXPECT_EQ(original_nodes_num - 4 + 1, current_nodes_num);
|
|
|
|
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added, current_nodes_num);
|
|
|
|
// Assert conv_relu op in newly generated graph
|
|
|
|
// Assert conv_relu op in newly generated graph
|
|
|
|
int conv_count = 0;
|
|
|
|
int conv_count = 0;
|
|
|
|
int elementwise_add_count = 0;
|
|
|
|
int elementwise_add_count = 0;
|
|
|
|