pull/1379/head
chuxing 4 years ago
parent e49e09d7a7
commit 6eddcd2d95

@ -1637,6 +1637,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem
auto temp_graph = MakeShared<ComputeGraph>("temp");
GE_CHECK_NOTNULL(temp_graph);
auto wrapper_node = temp_graph->AddNode(wrapper_op_desc);
wrapper_op_desc->SetId(parent_node_item->node_id);
GeModelPtr ge_model = subgraph_models_[subgraph_name];
GE_CHECK_NOTNULL(ge_model);
hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model);
@ -1916,7 +1917,6 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root
NodeItem *node_item = nullptr;
GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item));
GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item));
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item));
GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task
node_item->input_start = input_start;
@ -2069,22 +2069,17 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) {
}
Status HybridModelBuilder::ParseDependentByParallelGroup() {
for (auto &it : hybrid_model_.node_items_) {
GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get()));
}
for (const auto &it : node_to_parallel_groups_) {
auto node_item = it.first;
auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node);
for (const auto &parallel_group : it.second) {
auto &dependent_nodes = parallel_group_to_nodes_[parallel_group];
NodeItem *nearest_dep_node = nullptr;
int max_id = -1;
for (auto &dep_node : dependent_nodes) {
if (node_item == dep_node) {
continue;
}
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node);
if (src_engine_type == dst_engine_type) {
continue;
}
if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) {
nearest_dep_node = dep_node;
max_id = dep_node->node_id;
@ -2092,10 +2087,12 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() {
}
if (nearest_dep_node != nullptr) {
GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]",
parallel_group.c_str(),
nearest_dep_node->NodeName().c_str(),
node_item->NodeName().c_str());
GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str());
auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node);
if (src_engine_type == dst_executor_type) {
GELOGD("No need to add dependency for nodes with same executor type");
continue;
}
auto &deps = node_item->dependents_for_execution;
if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) {
GELOGD("%s->%s Already has dependency, skip it",
@ -2105,6 +2102,10 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() {
}
nearest_dep_node->has_observer = true;
deps.emplace_back(nearest_dep_node->node);
GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]",
parallel_group.c_str(),
nearest_dep_node->NodeName().c_str(),
node_item->NodeName().c_str());
}
}
}

Loading…
Cancel
Save