Bugfix: fix null owner graph error

pull/1412/head
lichun 4 years ago
parent 322e3c1c87
commit 7a3dba72af

@ -135,6 +135,7 @@ class HybridModel {
std::string model_name_;
GeRootModelPtr ge_root_model_;
std::map<uint32_t, NodeItem *> input_nodes_;
ComputeGraphPtr root_graph_;
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_;

@ -764,7 +764,7 @@ Status HybridModelBuilder::LoadGraph() {
root_graph->GetAllNodesSize());
}
root_graph_ = root_graph;
hybrid_model_.root_graph_ = root_graph;
// Reset node id by topological order across all subgraphs
int64_t index = 0;
for (const auto &node : root_graph->GetAllNodes()) {
@ -2058,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
GELOGD("[%s] Start to get parallel group from subgraph: %s",
node_item->NodeName().c_str(),
subgraph_name.c_str());
auto subgraph = root_graph_->GetSubgraph(subgraph_name);
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name);
GE_CHECK_NOTNULL(subgraph);
for (const auto &sub_node : subgraph->GetAllNodes()) {
std::string parallel_group;

@ -100,7 +100,6 @@ class HybridModelBuilder {
NodeItem *MutableNodeItem(const NodePtr &node);
GeRootModelPtr ge_root_model_;
ComputeGraphPtr root_graph_;
std::map<std::string, GeModelPtr> subgraph_models_;
std::map<std::string, NodePtr> constant_op_nodes_;
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_;

@ -1 +1 @@
Subproject commit 86781b7e8ce21d2b901406cc3619d6bea2aeb18e
Subproject commit 4ff5e3987f2e5d2980019defacaf0891861c84fc

@ -276,9 +276,9 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) {
op_desc->SetOpKernelLibName("ops_kernel_info_hccl");
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);
model.root_graph_ = compute_graph;
HybridModelBuilder builder(model);
builder.root_graph_ = compute_graph;
ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS);
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);

Loading…
Cancel
Save