diff --git a/ge/CMakeLists.txt b/ge/CMakeLists.txt index bd9edd86..9f2b21d7 100755 --- a/ge/CMakeLists.txt +++ b/ge/CMakeLists.txt @@ -28,59 +28,59 @@ set(PROTO_HEADER_LIST "${METADEF_DIR}/proto/op_mapping_info.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) -protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) -protobuf_generate(ge_client PROTO_CLIENT_HEADER_SRCS PROTO_CLIENT_HEADER_HDRS ${PROTO_HEADER_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) +#protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) +#protobuf_generate(ge_client PROTO_CLIENT_HEADER_SRCS PROTO_CLIENT_HEADER_HDRS ${PROTO_HEADER_LIST}) -if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) -############ libge_proto_common.a ############ -add_library(ge_proto_common STATIC - ${PROTO_HEADER_HDRS} - ${PROTO_SRCS} -) - -target_compile_definitions(ge_proto_common PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - google=ascend_private -) - -target_compile_options(ge_proto_common PRIVATE - -O2 - -fno-common -) - -target_link_libraries(ge_proto_common PRIVATE - $ - ascend_protobuf -) - -############ libge_proto_client.a ############ -add_library(ge_proto_client STATIC - ${PROTO_CLIENT_HEADER_HDRS} - ${PROTO_CLIENT_SRCS} -) - -target_compile_definitions(ge_proto_client PRIVATE - PROTOBUF_INLINE_NOT_IN_HEADERS=0 - google=ascend_private -) - -target_include_directories(ge_proto_client PRIVATE - ${CMAKE_BINARY_DIR}/proto/ge_client - ${CMAKE_BINARY_DIR}/proto/ge_client/proto -) - -target_compile_options(ge_proto_client PRIVATE - -O2 - -fno-common -) - -target_link_libraries(ge_proto_client PRIVATE - $ - ascend_protobuf -) -endif () +#if (NOT ENABLE_D AND NOT ENABLE_ACL AND NOT ENABLE_MS_TESTCASES) +############# libge_proto_common.a ############ +#add_library(ge_proto_common STATIC +# ${PROTO_HEADER_HDRS} +# ${PROTO_SRCS} +#) +# +#target_compile_definitions(ge_proto_common PRIVATE +# PROTOBUF_INLINE_NOT_IN_HEADERS=0 +# google=ascend_private +#) +# +#target_compile_options(ge_proto_common PRIVATE +# -O2 +# -fno-common +#) +# +#target_link_libraries(ge_proto_common PRIVATE +# $ +# ascend_protobuf +#) +# +############# libge_proto_client.a ############ +#add_library(ge_proto_client STATIC +# ${PROTO_CLIENT_HEADER_HDRS} +# ${PROTO_CLIENT_SRCS} +#) +# +#target_compile_definitions(ge_proto_client PRIVATE +# PROTOBUF_INLINE_NOT_IN_HEADERS=0 +# google=ascend_private +#) +# +#target_include_directories(ge_proto_client PRIVATE +# ${CMAKE_BINARY_DIR}/proto/ge_client +# ${CMAKE_BINARY_DIR}/proto/ge_client/proto +#) +# +#target_compile_options(ge_proto_client PRIVATE +# -O2 +# -fno-common +#) +# +#target_link_libraries(ge_proto_client PRIVATE +# $ +# ascend_protobuf +#) +#endif () ################################################################## set(TRAIN_SRC_LIST diff --git a/ge/common/CMakeLists.txt b/ge/common/CMakeLists.txt index 75cb8ad1..585a42cb 100755 --- a/ge/common/CMakeLists.txt +++ b/ge/common/CMakeLists.txt @@ -15,8 +15,8 @@ set(PROTO_LIST "${METADEF_DIR}/proto/tensorflow/versions.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) set(SRC_LIST "context/ctx.cc" diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index fccdb57b..01c7de95 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -60,7 +60,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { default: std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + " and FORMAT_FRACTAL_NZ is not supported."; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return false; } } diff --git a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc index c36bffb5..36bea872 100755 --- a/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc +++ b/ge/common/formats/format_transfers/format_transfer_fractal_zz.cc @@ -59,7 +59,7 @@ bool CheckShape(Format format, const ShapeVector &shape) { default: std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + " and FORMAT_FRACTAL_ZZ is not supported."; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); return false; } } diff --git a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc index b09fd168..6817713a 100755 --- a/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc +++ b/ge/common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc @@ -92,7 +92,8 @@ Status CheckArgsForNhwcToNc1hwc0(const TransArgs &args) { Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { std::shared_ptr dst(new (std::nothrow) uint8_t[total_size], std::default_delete()); if (dst == nullptr) { - GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", + GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, + "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld, shape %s", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), total_size, ShapeToString(args.dst_shape).c_str()); return ACL_ERROR_GE_MEMORY_ALLOCATION; diff --git a/ge/common/formats/format_transfers/format_transfer_transpose.cc b/ge/common/formats/format_transfers/format_transfer_transpose.cc index 694777f3..49bb5cd6 100755 --- a/ge/common/formats/format_transfers/format_transfer_transpose.cc +++ b/ge/common/formats/format_transfers/format_transfer_transpose.cc @@ -50,21 +50,21 @@ std::map>> perm_args{ bool IsShapeArgValid(const std::vector &src_shape, const std::vector &perm_arg) { if (src_shape.empty()) { std::string error = "Failed to transpose, empty src shape"; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); - GELOGE(PARAM_INVALID, "Failed to transpose, empty src shape"); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); + GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "Failed to transpose, empty src shape"); return false; } for (auto dim : src_shape) { if (dim < 0) { std::string error = "Failed to transpose, negative dim in src shape " + FmtToStr(ShapeToString(src_shape)); - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return false; } } if (perm_arg.size() != src_shape.size()) { std::string error = "Failed to transpose, the size of src shape" + FmtToStr(src_shape.size()) + " and perm arg" + FmtToStr(perm_arg.size()) + " are different"; - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_SHAPE_INVALID, error.c_str()); return false; } @@ -73,7 +73,7 @@ bool IsShapeArgValid(const std::vector &src_shape, const std::vector(perm) >= perm_arg.size() || ++exists[perm] > 1) { std::string error = "Failed to transpose, duplicated perm arg " + FmtToStr(perm) + ", perm arg " + FmtToStr(JoinToString(perm_arg)); - GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str()); + GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_PARAM_INVALID, error.c_str()); return false; } } @@ -82,11 +82,11 @@ bool IsShapeArgValid(const std::vector &src_shape, const std::vector &src_shape, DataType src_data_type, const std::vector &perm_arg) { if (src == nullptr) { - GELOGE(PARAM_INVALID, "Failed to transpose, the src is null"); + GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Failed to transpose, the src is null"); return false; } if (GetSizeByDataType(src_data_type) < 0) { - GELOGE(UNSUPPORTED, "Failed to transpose, the data type %s is not support", + GELOGE(ACL_ERROR_GE_DATATYPE_INVALID, "Failed to transpose, the data type %s is not support", TypeUtils::DataTypeToSerialString(src_data_type).c_str()); return false; } diff --git a/ge/executor/CMakeLists.txt b/ge/executor/CMakeLists.txt index 363900d0..31f8dc4b 100644 --- a/ge/executor/CMakeLists.txt +++ b/ge/executor/CMakeLists.txt @@ -7,8 +7,8 @@ set(PROTO_LIST "${METADEF_DIR}/proto/dump_task.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_static PROTO_STATIC_SRCS PROTO_STATIC_HDRS ${PROTO_LIST}) set(SRC_LIST "ge_executor.cc" diff --git a/ge/ge_local_engine/CMakeLists.txt b/ge/ge_local_engine/CMakeLists.txt index ab767ccb..affd8f5a 100755 --- a/ge/ge_local_engine/CMakeLists.txt +++ b/ge/ge_local_engine/CMakeLists.txt @@ -19,9 +19,9 @@ set(OPS_KERNEL_SRC_LIST "ops_kernel_store/op/no_op.cc" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge_ops_shared PROTO_OPS_SHARED_SRCS PROTO_OPS_SHARED_HDRS ${PROTO_LIST}) -protobuf_generate(ge_ops_static PROTO_OPS_STATIC_SRCS PROTO_OPS_STATIC_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_ops_shared PROTO_OPS_SHARED_SRCS PROTO_OPS_SHARED_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_ops_static PROTO_OPS_STATIC_SRCS PROTO_OPS_STATIC_HDRS ${PROTO_LIST}) ############ libge_local_engine.so ############ add_library(ge_local_engine SHARED ${SRC_LIST} ${PROTO_HDRS}) diff --git a/ge/host_cpu_engine/CMakeLists.txt b/ge/host_cpu_engine/CMakeLists.txt index 8d84ee28..950a1e5c 100644 --- a/ge/host_cpu_engine/CMakeLists.txt +++ b/ge/host_cpu_engine/CMakeLists.txt @@ -2,8 +2,8 @@ set(PROTO_LIST "${METADEF_DIR}/proto/task.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) -protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge_atcstub PROTO_ATCSTUB_SRCS PROTO_ATCSTUB_HDRS ${PROTO_LIST}) set(SRC_LIST "engine/host_cpu_engine.cc" diff --git a/ge/hybrid/model/hybrid_model.cc b/ge/hybrid/model/hybrid_model.cc index a0217d52..86acc260 100644 --- a/ge/hybrid/model/hybrid_model.cc +++ b/ge/hybrid/model/hybrid_model.cc @@ -333,7 +333,7 @@ TensorValue *HybridModel::GetConstant(const NodePtr &node) const { return nullptr; } - auto it = constant_tensors_.find(node); + auto it = constant_tensors_.find(node->GetName()); if (it == constant_tensors_.end()) { GELOGD("constant not found, node name = [%s]", node->GetName().c_str()); return nullptr; diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index fae53679..e8c005f6 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -138,7 +138,7 @@ class HybridModel { std::map device_variable_nodes_; //lint !e148 std::map host_variable_nodes_; //lint !e148 std::map> variable_tensors_; - std::map> constant_tensors_; + std::map> constant_tensors_; std::map> task_defs_; std::map known_shape_sub_models_; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index a3b1da20..def32766 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -617,9 +617,9 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { +Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) { merged_graph = MakeShared("MergedGraph"); - for (const auto &node : root_graph.GetDirectNode()) { + for (const auto &node : root_graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -649,7 +649,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap } } } - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph), "[%s] Failed to merge subgraph.", subgraph->GetName().c_str()); } @@ -665,7 +665,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap return a_level < b_level; }); - for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { + for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) { GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), "Failed to add subgraph [%s]", @@ -675,8 +675,8 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap return SUCCESS; } -Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, - ComputeGraph &parent_graph, +Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph, + ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph) { auto parent_node = sub_graph.GetParentNode(); GE_CHECK_NOTNULL(parent_node); @@ -705,15 +705,24 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, } } - parent_graph.AddNode(sub_node); + if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { + for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) { + auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i); + GE_CHECK_NOTNULL(sub_sub_graph); + sub_sub_graph->SetParentGraph(root_graph); + } + } + + parent_graph->AddNode(sub_node); GELOGD("[%s::%s] added to parent graph: [%s].", sub_graph.GetName().c_str(), sub_node->GetName().c_str(), - parent_graph.GetName().c_str()); + parent_graph->GetName().c_str()); + sub_node->SetOwnerComputeGraph(root_graph); } GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); - root_graph.RemoveSubgraph(sub_graph.GetName()); + root_graph->RemoveSubgraph(sub_graph.GetName()); return SUCCESS; } @@ -765,7 +774,7 @@ Status HybridModelBuilder::LoadGraph() { GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); - GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs."); root_graph = std::move(merged_graph); GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), @@ -1035,6 +1044,14 @@ Status HybridModelBuilder::InitWeights() { sub_weight_buffer->GetSize()); auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph()); hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer)); + + std::map name_to_node; + for (const auto &subgraph : ge_root_model_->GetRootGraph()->GetAllSubgraphs()) { + for (const auto &node : subgraph->GetAllNodes()) { + name_to_node.insert(make_pair(node->GetName(), node)); + } + } + for (auto &node : root_graph->GetDirectNode()) { if (node->GetType() != CONSTANT) { continue; @@ -1065,10 +1082,25 @@ Status HybridModelBuilder::InitWeights() { auto tensor_buffer = TensorBuffer::Create(weight_base + data_offset, tensor_size); GE_CHECK_NOTNULL(tensor_buffer); + + if (tensor_size > 0) { + auto tensor = std::shared_ptr( + new (std::nothrow)GeTensor(tensor_desc, weight_buffer.GetData() + data_offset, tensor_size)); + OpDescPtr op_desc = nullptr; + auto iter = name_to_node.find(node->GetName()); + if (iter != name_to_node.end()) { + op_desc = iter->second->GetOpDesc(); + if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, std::move(tensor))) { + GELOGE(FAILED, "Set attr ATTR_NAME_WEIGHTS failed."); + return FAILED; + } + } + } + std::unique_ptr constant_tensor(new (std::nothrow)TensorValue(std::move(tensor_buffer))); GE_CHECK_NOTNULL(constant_tensor); constant_tensor->SetName("Constant_" + op_desc->GetName()); - hybrid_model_.constant_tensors_.emplace(node, std::move(constant_tensor)); + hybrid_model_.constant_tensors_.emplace(node->GetName(), std::move(constant_tensor)); GELOGD("[%s] Constant node [%s] added, size = %ld", GetGraphName(), node->GetName().c_str(), tensor_size); } } diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 313d5ca6..a6b2b25a 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -47,8 +47,8 @@ class HybridModelBuilder { static Status HandleDtString(const GeTensor &tensor, void *var_addr); static Status MergeInputNodes(ComputeGraph &compute_graph); static Status MergeNetOutputNode(ComputeGraph &compute_graph); - static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); - static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); + static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph); + static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph); static Status BuildInputMapping(GraphItem &graph_item, std::vector &data_nodes, bool is_root_graph); diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt index 87589859..1e8a6cc5 100644 --- a/ge/offline/CMakeLists.txt +++ b/ge/offline/CMakeLists.txt @@ -5,7 +5,7 @@ set(PROTO_LIST "${METADEF_DIR}/proto/task.proto" ) -protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) +#protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) set(SRC_LIST "main.cc" diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 3b5d19e6..de84342c 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -245,7 +245,7 @@ TEST_F(UtestGeHybrid, init_weight_success) { ASSERT_EQ(ret,PARAM_INVALID); } - TEST_F(UtestGeHybrid, hybrid_model_executor) { +TEST_F(UtestGeHybrid, hybrid_model_executor) { ComputeGraphPtr compute_graph = MakeShared("abc"); GeRootModelPtr root_model = MakeShared(compute_graph); HybridModel model(root_model); @@ -256,3 +256,53 @@ TEST_F(UtestGeHybrid, init_weight_success) { HybridModelExecutor executor(model_ptr, device_id, stream); executor.Init(); } + +TEST_F(UtestGeHybrid, unfold_subgraphs_success) { + ComputeGraphPtr merged_graph = nullptr; + + ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); + OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); + NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); + + ComputeGraphPtr sub_sub_graph2 = std::make_shared("while_body"); + /*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT); + NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/ + OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); + NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); + sub_sub_graph2->SetGraphUnknownFlag(true); + /*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD); + NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node); + sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/ + + ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); + OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE); + NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); + sub_graph->SetGraphUnknownFlag(true); + sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond"); + sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body"); + sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond"); + sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body"); + + ComputeGraphPtr root_graph = std::make_shared("root_graph"); + auto partitioned_call_op_desc = MakeShared("partitioned_call", PARTITIONEDCALL); + auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); + partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph"); + partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); + + root_graph->AddSubGraph(sub_sub_graph1); + root_graph->AddSubGraph(sub_sub_graph2); + sub_sub_graph1->SetParentGraph(root_graph); + sub_sub_graph2->SetParentGraph(root_graph); + sub_sub_graph1->SetParentNode(sub_graph_while_node); + sub_sub_graph2->SetParentNode(sub_graph_while_node); + + root_graph->AddSubGraph(sub_graph); + sub_graph->SetParentNode(partitioned_call_node); + sub_graph->SetParentGraph(root_graph); + + GeRootModelPtr root_model = MakeShared(root_graph); + HybridModel hybrid_model(root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS); +}