diff --git a/tests/ut/ge/generator/ge_generator_unittest.cc b/tests/ut/ge/generator/ge_generator_unittest.cc index 3daa5592..bb8a0513 100644 --- a/tests/ut/ge/generator/ge_generator_unittest.cc +++ b/tests/ut/ge/generator/ge_generator_unittest.cc @@ -20,6 +20,11 @@ #define protected public #include "generator/ge_generator.h" #include "graph/utils/tensor_utils.h" +#include "graph/attr_value.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "../graph/passes/graph_builder_utils.h" +#include "../graph/manager/graph_manager.h using namespace std; @@ -31,6 +36,16 @@ class UtestGeGenerator : public testing::Test { void TearDown() {} }; +namespace { +ComputeGraphPtr MakeGraph() { + ge::ut::GraphBuilder builder("graph"); + auto data = builder.AddNode("data", "Data", 1, 1); + auto addn1 = builder.AddNode("addn1", "AddN", 1, 1); + builder.AddDataEdge(data, 0, addn1, 0); + return builder.GetGraph(); +} +} // namespace + /* TEST_F(UtestGeGenerator, test_build_single_op_offline) { GeTensorDesc tensor_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); @@ -71,4 +86,28 @@ TEST_F(UtestGeGenerator, test_build_single_op_online) { ModelBufferData model_buffer; EXPECT_EQ(generator.BuildSingleOpModel(op_desc, inputs, outputs, ENGINE_AIVECTOR, model_buffer), FAILED); } + +TEST_F(UtestGeGenerator, test_graph_manager) { + GraphManager graph_manager; + GraphPartitioner graph_partitioner; + + auto root_graph = MakeGraph(); + auto sub_graph = MakeGraph(); + root_graph->AddSubGraph(sub_graph); + + auto sgi = MakeShared(); + // set engine name + sgi->SetEngineName("AIcoreEngine"); + sgi->SetSubGraph(sub_graph); + + auto sgi_gelocal = MakeShared(); + // set engine name + sgi_gelocal->SetEngineName("GELOCAL"); + sgi_gelocal->SetSubGraph(sub_graph); + + graph_partitioner.graph_2_input_subgraph_[root_graph] = sgi_gelocal; + graph_partitioner.graph_2_subgraph_list_.insert({root_graph, {sgi, sgi_gelocal}}); + graph_partitioner.graph_2_subgraph_list_.insert({sub_graph, {sgi, sgi_gelocal}}); + EXPECT_EQ(graph_manager.ConvertGraphToFile(root_graph, graph_partitioner, "./"), GRAPH_SUCCESS); +} } // namespace ge