From bb81990dd3dc5d4db6e6e168a1c9e4d22d60d9b0 Mon Sep 17 00:00:00 2001 From: lichun Date: Tue, 23 Mar 2021 14:38:30 +0800 Subject: [PATCH] Bugfix: support unknown control op subgraph --- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 39 +++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 3b5d19e6..c13a9003 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -245,7 +245,7 @@ TEST_F(UtestGeHybrid, init_weight_success) { ASSERT_EQ(ret,PARAM_INVALID); } - TEST_F(UtestGeHybrid, hybrid_model_executor) { +TEST_F(UtestGeHybrid, hybrid_model_executor) { ComputeGraphPtr compute_graph = MakeShared("abc"); GeRootModelPtr root_model = MakeShared(compute_graph); HybridModel model(root_model); @@ -256,3 +256,40 @@ TEST_F(UtestGeHybrid, init_weight_success) { HybridModelExecutor executor(model_ptr, device_id, stream); executor.Init(); } + +TEST_F(UtestGeHybrid, unfold_subgraphs_success) { + ComputeGraphPtr merged_graph = nullptr; + + ComputeGraphPtr root_graph = std::make_shared("root_graph"); + auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); + auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); + + ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); + OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); + NodePtr sub_graph_while_node = sub_graph->AddNode(op_desc); + + ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); + OpDescPtr sub_sub_graph_while_cond_const_op_desc = CreateOpDesc("cond_const", CONST); + NodePtr sub_sub_graph_while_cond_const_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_const_op_desc); + + ComputeGraphPtr sub_sub_graph2 = std::make_shared("while body"); + OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONST); + NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc); + OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); + NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); + OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); + NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node); + + sub_sub_graph1->AddSubGraph(sub_sub_graph1); + sub_sub_graph2->AddSubGraph(sub_sub_graph2); + + root_graph->AddSubGraph(sub_graph); + sub_graph->SetParentNode(partitioned_call_node); + + GeRootModelPtr root_model = MakeShared(root_graph); + HybridModel hybrid_model(root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); +}