|
|
|
@ -39,7 +39,7 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
|
|
|
|
|
* return transpose
|
|
|
|
|
*/
|
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
|
|
|
|
|
std::vector<int> shp{2, 4, 8, 16};
|
|
|
|
|
std::vector<int> shp{2, 2, 16, 16};
|
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
|
AbstractBasePtrList args_spec_list{x_abstract};
|
|
|
|
|
auto kg = GetKernelGraph(g, args_spec_list);
|
|
|
|
@ -59,5 +59,26 @@ TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_fusion) {
|
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "after");
|
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWReshapeTransposeFusion, test_reshape_transpose_no_fusion) {
|
|
|
|
|
/*
|
|
|
|
|
* def before(input0, input1):
|
|
|
|
|
* reshape = Reshape(input0, input1)
|
|
|
|
|
* transpose = Transpose(reshape)
|
|
|
|
|
* return transpose
|
|
|
|
|
*/
|
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_reshape_transpose_fusion", "before");
|
|
|
|
|
std::vector<int> shp{2, 4, 8, 16};
|
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
|
AbstractBasePtrList args_spec_list{x_abstract};
|
|
|
|
|
auto kg = GetKernelGraph(g, args_spec_list);
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
|
pm->AddPass(std::make_shared<opt::ReshapeTransposeFusion>());
|
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|