|
|
|
@ -28,7 +28,7 @@ class TestHWBatchNormBertFission : public BackendCommon {
|
|
|
|
|
UT::PyFuncGraphFetcher get_py_fun_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
|
|
|
|
|
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fission) {
|
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before");
|
|
|
|
|
EXPECT_NE(g, nullptr);
|
|
|
|
|
std::vector<int> shp_x{32, 64, 112, 112};
|
|
|
|
@ -40,6 +40,23 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
|
|
|
|
|
args_spec_list.push_back(y_abstract);
|
|
|
|
|
}
|
|
|
|
|
auto kg = GetKernelGraph(g, args_spec_list);
|
|
|
|
|
auto ret = kg->get_return();
|
|
|
|
|
EXPECT_NE(ret, nullptr);
|
|
|
|
|
auto make_tuple0 = ret->input(1);
|
|
|
|
|
EXPECT_NE(make_tuple0, nullptr);
|
|
|
|
|
auto tuple_getitem0 = make_tuple0->cast<CNodePtr>()->input(1);
|
|
|
|
|
EXPECT_NE(tuple_getitem0, nullptr);
|
|
|
|
|
auto make_tuple1 = tuple_getitem0->cast<CNodePtr>()->input(1);
|
|
|
|
|
EXPECT_NE(make_tuple1, nullptr);
|
|
|
|
|
auto tuple_getitem1 = make_tuple1->cast<CNodePtr>()->input(1);
|
|
|
|
|
EXPECT_NE(tuple_getitem1, nullptr);
|
|
|
|
|
auto bn = tuple_getitem1->cast<CNodePtr>()->input(1);
|
|
|
|
|
EXPECT_NE(bn, nullptr);
|
|
|
|
|
auto bn_cnode = bn->cast<CNodePtr>();
|
|
|
|
|
EXPECT_NE(bn_cnode, nullptr);
|
|
|
|
|
auto inputs = bn_cnode->inputs();
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs(inputs.begin(), inputs.begin() + 4);
|
|
|
|
|
bn_cnode->set_inputs(new_inputs);
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
@ -50,5 +67,27 @@ TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_fusion) {
|
|
|
|
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "after");
|
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TEST_F(TestHWBatchNormBertFission, test_fused_batch_norm_no_fission) {
|
|
|
|
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_batch_norm_bert_fission", "before");
|
|
|
|
|
EXPECT_NE(g, nullptr);
|
|
|
|
|
std::vector<int> shp_x{32, 64, 112, 112};
|
|
|
|
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_x);
|
|
|
|
|
std::vector<int> shp_y{64};
|
|
|
|
|
auto y_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp_y);
|
|
|
|
|
AbstractBasePtrList args_spec_list{x_abstract};
|
|
|
|
|
for (size_t i = 0; i < 4; ++i) {
|
|
|
|
|
args_spec_list.push_back(y_abstract);
|
|
|
|
|
}
|
|
|
|
|
auto kg = GetKernelGraph(g, args_spec_list);
|
|
|
|
|
|
|
|
|
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
|
|
|
|
auto pm = std::make_shared<opt::PassManager>();
|
|
|
|
|
pm->AddPass(std::make_shared<opt::BatchNormBertFission>());
|
|
|
|
|
optimizer->AddPassManager(pm);
|
|
|
|
|
FuncGraphPtr new_graph = optimizer->Optimize(kg);
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(CheckEqualGraph(kg, new_graph));
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|