|
|
|
@ -14,7 +14,6 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
|
|
|
|
|
|
|
|
|
|
#include <gmock/gmock.h>
|
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
|
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
|
|
|
|
|
|
|
|
|
@ -97,11 +96,11 @@ void TestMain(bool with_xshapes) {
|
|
|
|
|
|
|
|
|
|
auto check = [&matmul_op_desc](std::string a) {
|
|
|
|
|
std::string shape_str = "fused_reshape_" + a;
|
|
|
|
|
EXPECT_THAT(matmul_op_desc->GetAttrIfExists<std::vector<int>>(shape_str),
|
|
|
|
|
testing::ElementsAre(0, 0, 12, 64));
|
|
|
|
|
auto shape = matmul_op_desc->GetAttrIfExists<std::vector<int>>(shape_str);
|
|
|
|
|
EXPECT_EQ(shape, (std::vector<int>{0, 0, 12, 64}));
|
|
|
|
|
std::string axis_str = "fused_transpose_" + a;
|
|
|
|
|
EXPECT_THAT(matmul_op_desc->GetAttrIfExists<std::vector<int>>(axis_str),
|
|
|
|
|
testing::ElementsAre(0, 2, 1, 3));
|
|
|
|
|
auto axis = matmul_op_desc->GetAttrIfExists<std::vector<int>>(axis_str);
|
|
|
|
|
EXPECT_EQ(axis, (std::vector<int>{0, 2, 1, 3}));
|
|
|
|
|
};
|
|
|
|
|
check("X");
|
|
|
|
|
check("Y");
|
|
|
|
|