support unknown while subgraph

pull/1415/head
lichun 4 years ago
parent 7ed03d0d0e
commit 12cef9e9b9

@ -135,6 +135,7 @@ class HybridModel {
std::string model_name_;
GeRootModelPtr ge_root_model_;
std::map<uint32_t, NodeItem *> input_nodes_;
ComputeGraphPtr root_graph_;
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_;

@ -136,12 +136,12 @@ Status HybridModelBuilder::Build() {
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(), "[%s] Failed to RecoverGraphUnknownFlag", GetGraphName());
GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName());
GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName());
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName());
GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName());
GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName());
GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName());
GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName());
GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName());
GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName());
GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName());
GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName());
GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName());
@ -599,9 +599,10 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) {
return SUCCESS;
}
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) {
Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph) {
merged_graph = MakeShared<ComputeGraph>("MergedGraph");
for (const auto &node : root_graph.GetDirectNode()) {
merged_graph->SetGraphUnknownFlag(root_graph->GetGraphUnknownFlag());
for (const auto &node : root_graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
@ -631,7 +632,7 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
}
}
}
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph),
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, merged_graph, *subgraph),
"[%s] Failed to merge subgraph.",
subgraph->GetName().c_str());
}
@ -647,18 +648,19 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap
return a_level < b_level;
});
for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) {
for (auto &remained_subgraph : root_graph->GetAllSubgraphs()) {
GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str());
GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph),
"Failed to add subgraph [%s]",
remained_subgraph->GetName().c_str());
remained_subgraph->SetParentGraph(merged_graph);
}
return SUCCESS;
}
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph,
ComputeGraph &parent_graph,
Status HybridModelBuilder::UnfoldSubgraph(ComputeGraphPtr &root_graph,
ComputeGraphPtr &parent_graph,
ComputeGraph &sub_graph) {
auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node);
@ -687,15 +689,23 @@ Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph,
}
}
parent_graph.AddNode(sub_node);
if (!sub_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) {
for (size_t i = 0; i < sub_node->GetOpDesc()->GetSubgraphInstanceNames().size(); ++i) {
auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, i);
GE_CHECK_NOTNULL(sub_sub_graph);
sub_sub_graph->SetParentGraph(parent_graph);
}
}
parent_graph->AddNode(sub_node);
GELOGD("[%s::%s] added to parent graph: [%s].",
sub_graph.GetName().c_str(),
sub_node->GetName().c_str(),
parent_graph.GetName().c_str());
parent_graph->GetName().c_str());
sub_node->SetOwnerComputeGraph(parent_graph);
}
GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str());
root_graph.RemoveSubgraph(sub_graph.GetName());
root_graph->RemoveSubgraph(sub_graph.GetName());
return SUCCESS;
}
@ -747,14 +757,14 @@ Status HybridModelBuilder::LoadGraph() {
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize());
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs.");
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph), "Failed to unfold subgraphs.");
root_graph = std::move(merged_graph);
GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize());
}
root_graph_ = root_graph;
hybrid_model_.root_graph_ = root_graph;
// Reset node id by topological order across all subgraphs
int64_t index = 0;
for (const auto &node : root_graph->GetAllNodes()) {
@ -1030,9 +1040,13 @@ Status HybridModelBuilder::InitWeights() {
GELOGI("Init weight mem successfully, weight base %p, weight size = %zu",
weight_base,
sub_weight_buffer->GetSize());
auto root_graph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph());
hybrid_model_.weight_buffer_map_.emplace(root_graph->GetName(),std::move(sub_weight_buffer));
for (auto &node : root_graph->GetDirectNode()) {
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph());
if (subgraph != ge_root_model_->GetRootGraph()) {
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first);
}
GE_CHECK_NOTNULL(subgraph);
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer));
for (auto &node : subgraph->GetDirectNode()) {
if (node->GetType() != CONSTANT) {
continue;
}
@ -2044,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
GELOGD("[%s] Start to get parallel group from subgraph: %s",
node_item->NodeName().c_str(),
subgraph_name.c_str());
auto subgraph = root_graph_->GetSubgraph(subgraph_name);
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name);
GE_CHECK_NOTNULL(subgraph);
for (const auto &sub_node : subgraph->GetAllNodes()) {
std::string parallel_group;

@ -47,8 +47,8 @@ class HybridModelBuilder {
static Status HandleDtString(const GeTensor &tensor, void *var_addr);
static Status MergeInputNodes(ComputeGraph &compute_graph);
static Status MergeNetOutputNode(ComputeGraph &compute_graph);
static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph);
static Status UnfoldSubgraphs(ComputeGraphPtr &root_graph, ComputeGraphPtr &merged_graph);
static Status UnfoldSubgraph(ComputeGraphPtr &root_graph, ComputeGraphPtr &parent_graph, ComputeGraph &sub_graph);
static Status BuildInputMapping(GraphItem &graph_item,
std::vector<NodeItem *> &data_nodes,
bool is_root_graph);
@ -100,7 +100,6 @@ class HybridModelBuilder {
NodeItem *MutableNodeItem(const NodePtr &node);
GeRootModelPtr ge_root_model_;
ComputeGraphPtr root_graph_;
std::map<std::string, GeModelPtr> subgraph_models_;
std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;

@ -256,3 +256,53 @@ TEST_F(UtestGeHybrid, init_weight_success) {
HybridModelExecutor executor(model_ptr, device_id, stream);
executor.Init();
}
TEST_F(UtestGeHybrid, unfold_subgraphs_success) {
ComputeGraphPtr merged_graph = nullptr;
ComputeGraphPtr sub_sub_graph1 = std::make_shared<ComputeGraph>("while_cond");
OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA);
NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc);
ComputeGraphPtr sub_sub_graph2 = std::make_shared<ComputeGraph>("while_body");
/*OpDescPtr sub_sub_graph_while_body_const_op_desc = CreateOpDesc("body_const", CONSTANT);
NodePtr sub_sub_graph_while_body_const_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_const_op_desc);*/
OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA);
NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc);
sub_sub_graph2->SetGraphUnknownFlag(true);
/*OpDescPtr sub_sub_graph_while_body_add_op_desc = CreateOpDesc("body_add", ADD);
NodePtr sub_sub_graph_while_body_add_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_add_node);
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_data_node);
sub_sub_graph_while_body_add_node->AddLinkFrom(sub_sub_graph_while_body_const_node);*/
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE);
NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc);
sub_graph->SetGraphUnknownFlag(true);
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_cond");
sub_graph_while_node->GetOpDesc()->AddSubgraphName("while_body");
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(0, "while_cond");
sub_graph_while_node->GetOpDesc()->SetSubgraphInstanceName(1, "while_body");
ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
auto partitioned_call_op_desc = MakeShared<OpDesc>("partitioned_call", PARTITIONEDCALL);
auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc);
partitioned_call_node->GetOpDesc()->AddSubgraphName("sub_graph");
partitioned_call_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph");
root_graph->AddSubGraph(sub_sub_graph1);
root_graph->AddSubGraph(sub_sub_graph2);
sub_sub_graph1->SetParentGraph(root_graph);
sub_sub_graph2->SetParentGraph(root_graph);
sub_sub_graph1->SetParentNode(sub_graph_while_node);
sub_sub_graph2->SetParentNode(sub_graph_while_node);
root_graph->AddSubGraph(sub_graph);
sub_graph->SetParentNode(partitioned_call_node);
sub_graph->SetParentGraph(root_graph);
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(root_graph);
HybridModel hybrid_model(root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
EXPECT_EQ(hybrid_model_builder.UnfoldSubgraphs(root_graph, merged_graph), SUCCESS);
}

Loading…
Cancel
Save