Reshape transpose matmul coverage (#24970)

* remove gmock from ut

test=develop

* coverage enabled for r+t+m fuse pass

test=develop
revert-24981-add_device_attr_for_regulization
Sylwester Fraczek 5 years ago committed by GitHub
parent 2e238c6eed
commit 53d563a0fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -157,8 +157,6 @@ endif()
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
if(NOT WITH_COVERAGE)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
endif()
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
endif ()

@ -14,7 +14,6 @@
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
@ -97,11 +96,11 @@ void TestMain(bool with_xshapes) {
auto check = [&matmul_op_desc](std::string a) {
std::string shape_str = "fused_reshape_" + a;
EXPECT_THAT(matmul_op_desc->GetAttrIfExists<std::vector<int>>(shape_str),
testing::ElementsAre(0, 0, 12, 64));
auto shape = matmul_op_desc->GetAttrIfExists<std::vector<int>>(shape_str);
EXPECT_EQ(shape, (std::vector<int>{0, 0, 12, 64}));
std::string axis_str = "fused_transpose_" + a;
EXPECT_THAT(matmul_op_desc->GetAttrIfExists<std::vector<int>>(axis_str),
testing::ElementsAre(0, 2, 1, 3));
auto axis = matmul_op_desc->GetAttrIfExists<std::vector<int>>(axis_str);
EXPECT_EQ(axis, (std::vector<int>{0, 2, 1, 3}));
};
check("X");
check("Y");

Loading…
Cancel
Save