|
|
|
@ -49,5 +49,47 @@ TEST_F(TestHWOptSoftmaxGradExtFusion, test_fusion) {
|
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion", "after");
|
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWOptSoftmaxGradExtFusion, test_fusion_v2) {
|
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion_v2", "before");
|
|
|
|
|
EXPECT_NE(g, nullptr);
|
|
|
|
|
std::vector<int> shp{1, 1, 1, 1};
|
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
|
for (size_t i = 0; i < 3; ++i) {
|
|
|
|
|
args_spec_list.push_back(x_abstract);
|
|
|
|
|
}
|
|
|
|
|
auto fg = GetKernelGraph(g, args_spec_list);
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
|
pm->AddPass(std::make_shared<opt::SoftmaxGradExtFusionV2>());
|
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion_v2", "after");
|
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWOptSoftmaxGradExtFusion, test_fusion_v3) {
|
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion_v3", "before");
|
|
|
|
|
EXPECT_NE(g, nullptr);
|
|
|
|
|
std::vector<int> shp{1, 1, 1, 1};
|
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
|
for (size_t i = 0; i < 3; ++i) {
|
|
|
|
|
args_spec_list.push_back(x_abstract);
|
|
|
|
|
}
|
|
|
|
|
auto fg = GetKernelGraph(g, args_spec_list);
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
|
pm->AddPass(std::make_shared<opt::SoftmaxGradExtFusionV3>());
|
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_softmax_grad_ext_fusion_v3", "after");
|
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|