diff --git a/ge/graph/passes/base_pass.cc b/ge/graph/passes/base_pass.cc index 3b854c18..0868b729 100755 --- a/ge/graph/passes/base_pass.cc +++ b/ge/graph/passes/base_pass.cc @@ -30,8 +30,15 @@ constexpr int kMaxRePassTimes = 10000; constexpr size_t kMaxOneInNodes = 1000; // Each iteration, we take about 0.3k memory on the stack, we should change the recursion to loop later constexpr int kMaxRecursiveDepth = 20; +struct DuringPassNodeSets { + std::unordered_set nodes_seen; + std::unordered_set nodes_deleted; + std::unordered_set nodes_re_pass; + std::unordered_set nodes_re_pass_immediately; + std::unordered_set nodes_last; +}; -void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &input_edge_nodes, +void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque &input_edge_nodes, std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { nodes_last.clear(); for (auto &node : graph->GetDirectNode()) { @@ -40,7 +47,7 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &i } size_t in_nums = node->GetInNodes().size(); if (in_nums == 0) { - input_edge_nodes.push(node); + input_edge_nodes.push_back(node); nodes_seen.insert(node.get()); } else if (in_nums > kMaxOneInNodes) { nodes_last.insert(node); @@ -48,7 +55,7 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::queue &i } } -void AddNextIterNodes(const Node::Vistor &nodes, std::queue &nodes_to_pass, +void AddNextIterNodes(const Node::Vistor &nodes, std::deque &nodes_to_pass, std::unordered_set &nodes_seen, std::unordered_set &nodes_last) { for (auto &node : nodes) { if (node == nullptr) { @@ -60,13 +67,30 @@ void AddNextIterNodes(const Node::Vistor &nodes, std::queue &n bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { - nodes_to_pass.push(node); + nodes_to_pass.push_back(node); } } } -Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unordered_set &nodes_re_pass, - std::unordered_set &nodes_deleted, std::unordered_set &nodes_seen) { +void PushToRePassIfSeen(NodePtr &node, const std::pair &name_to_pass, + std::unordered_set &nodes_seen, std::unordered_set &nodes_to_re_pass, + std::unordered_set &nodes_re_pass) { + for (const auto &node_to_re_pass : nodes_to_re_pass) { + if (node_to_re_pass == nullptr) { + GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), + node->GetName().c_str(), node->GetType().c_str()); + continue; + } + if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { + GELOGD("The node %s will be re-pass.", node_to_re_pass->GetName().c_str()); + nodes_re_pass.insert(node_to_re_pass); + } else { + GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); + } + } +} + +Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNodeSets &during_pass_node_set) { if (node == nullptr) { GELOGE(FAILED, "parameter is null."); return FAILED; @@ -90,22 +114,15 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder } auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); - for (const auto &node_to_re_pass : nodes_to_re_pass) { - if (node_to_re_pass == nullptr) { - GELOGW("Found null re-pass node when executing %s on node %s type %s", name_to_pass.first.c_str(), - node->GetName().c_str(), node->GetType().c_str()); - continue; - } - if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) { - GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str()); - nodes_re_pass.insert(node_to_re_pass); - } else { - GELOGD("The node %s are not all seen, don't set repass this time", node_to_re_pass->GetName().c_str()); - } - } + PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, + during_pass_node_set.nodes_re_pass); + + auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); + PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, + during_pass_node_set.nodes_re_pass_immediately); auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); - nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); + during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); if (nodes_deleted_by_pass.count(node) > 0) { GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), name_to_pass.first.c_str()); @@ -181,36 +198,33 @@ Status GEPass::Run(const NamesToPass &names_to_passes) { Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); - std::queue nodes; - std::unordered_set nodes_seen; - std::unordered_set nodes_deleted; - std::unordered_set nodes_re_pass; - std::unordered_set nodes_last; - GetAllNodesNoInputEdge(graph_, nodes, nodes_seen, nodes_last); + std::deque nodes; + DuringPassNodeSets during_pass_node_set; + GetAllNodesNoInputEdge(graph_, nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); GELOGD("Start points count %zu", nodes.size()); int re_pass_times = 0; do { - for (auto &node : nodes_re_pass) { - nodes.push(node); - nodes_seen.insert(node.get()); + for (auto &node : during_pass_node_set.nodes_re_pass) { + nodes.push_back(node); + during_pass_node_set.nodes_seen.insert(node.get()); } - nodes_re_pass.clear(); + during_pass_node_set.nodes_re_pass.clear(); while (!nodes.empty()) { NodePtr node = nodes.front(); - nodes.pop(); + nodes.pop_front(); - (void)nodes_re_pass.erase(node); + (void)during_pass_node_set.nodes_re_pass.erase(node); GE_IF_BOOL_EXEC(node == nullptr, GELOGW("node is null"); continue); - if (nodes_deleted.count(node) > 0) { + if (during_pass_node_set.nodes_deleted.count(node) > 0) { GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); continue; } - AddNextIterNodes(node->GetOutNodes(), nodes, nodes_seen, nodes_last); + AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); - auto ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_deleted, nodes_seen); + auto ret = RunPasses(node, names_to_passes, during_pass_node_set); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", node->GetName().c_str(), node->GetType().c_str(), ret); @@ -227,7 +241,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { if (has_sub_graph) { GELOGD("There are subgraphs on node %s, run passes for for the second time", node->GetName().c_str()); SetFlagOption(kOptimizeAfterSubGraph, names_to_passes); - ret = RunPasses(node, names_to_passes, nodes_re_pass, nodes_deleted, nodes_seen); + ret = RunPasses(node, names_to_passes, during_pass_node_set); if (ret != SUCCESS) { GELOGE(ret, "Failed to process passes on node %s type %s, error code: %u", node->GetName().c_str(), node->GetType().c_str(), ret); @@ -239,16 +253,21 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { // should be called each time at the begin of the iteration ClearOption(names_to_passes); } + for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { + GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); + nodes.push_front(node); + } + during_pass_node_set.nodes_re_pass_immediately.clear(); } - for (auto &node : nodes_last) { - bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); - if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { - nodes.push(node); + for (auto &node : during_pass_node_set.nodes_last) { + bool all_in_nodes_seen = node->IsAllInNodesSeen(during_pass_node_set.nodes_seen); + if (all_in_nodes_seen && during_pass_node_set.nodes_seen.insert(node.get()).second) { + nodes.push_back(node); } } - nodes_last.clear(); - } while ((!nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); + during_pass_node_set.nodes_last.clear(); + } while ((!during_pass_node_set.nodes_re_pass.empty() || !nodes.empty()) && ++re_pass_times < kMaxRePassTimes); if (re_pass_times == kMaxRePassTimes) { GELOGW("re_pass_times should not come to %d", kMaxRePassTimes); diff --git a/ge/graph/passes/base_pass.h b/ge/graph/passes/base_pass.h index bb41691d..89a364a9 100644 --- a/ge/graph/passes/base_pass.h +++ b/ge/graph/passes/base_pass.h @@ -53,6 +53,8 @@ class BaseNodePass { std::unordered_set GetNodesNeedRePass() { return nodes_need_re_pass_; } + std::unordered_set GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } + std::unordered_set GetNodesDeleted() { return nodes_deleted_; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } @@ -79,6 +81,14 @@ class BaseNodePass { /// void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } + /// + /// Add a node to be optimized immediately again. If you add a new node to the graph, or + /// change a node connections, and you want to make sure the node will be + /// optimized by other passes, call this function. + /// @param node + /// + void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } + /// /// Add a node and it's input/output data nodes to be optimized again. /// @param node @@ -109,6 +119,7 @@ class BaseNodePass { private: std::unordered_set nodes_need_re_pass_; + std::unordered_set nodes_need_re_pass_immediately_; std::unordered_set nodes_deleted_; std::map options_; }; diff --git a/ge/graph/passes/infershape_pass.cc b/ge/graph/passes/infershape_pass.cc index 7b8f7b50..a54a15c1 100755 --- a/ge/graph/passes/infershape_pass.cc +++ b/ge/graph/passes/infershape_pass.cc @@ -25,6 +25,7 @@ namespace ge { Status InferShapePass::Run(NodePtr &node) { + // kOptimizeAfterSubGraph exist means after subgraph auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); if (ret != GRAPH_SUCCESS) { // select INFERSHAPE failed info @@ -41,6 +42,20 @@ Status InferShapePass::Run(NodePtr &node) { GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } + bool need_repass = false; + auto has_attr = AttrUtils::GetBool(node->GetOpDesc(), "need_infer_again_", need_repass); + if (has_attr) { + if (!OptionExists(kOptimizeAfterSubGraph)) { + return SUCCESS; + } + if (need_repass) { + AddImmediateRePassNode(node); + GELOGD("Node %s need repass immediately.", node->GetName().c_str()); + } else { + // clear attr on while + node->GetOpDesc()->DelAttr("need_infer_again_"); + } + } return SUCCESS; } } // namespace ge diff --git a/ge/hybrid/executor/hybrid_model_async_executor.cc b/ge/hybrid/executor/hybrid_model_async_executor.cc index 9f37e7d5..0194a492 100644 --- a/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -67,6 +67,7 @@ Status HybridModelAsyncExecutor::Start(const std::shared_ptr &lis future_ = std::async(std::launch::async, [&]() -> Status { GetThreadLocalContext() = *executor_->GetContext()->ge_context; GetContext().SetSessionId(executor_->GetContext()->session_id); + GetContext().SetContextId(executor_->GetContext()->context_id); return RunInternal(); }); @@ -166,6 +167,7 @@ Status HybridModelAsyncExecutor::RunInternal() { } else { GELOGI("HybridModel will execute in singleline mode"); ge::GetContext().SetSessionId(executor_->GetContext()->session_id); + ge::GetContext().SetContextId(executor_->GetContext()->context_id); ret = executor_->Execute(args); } ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 45db9936..57e4052d 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -227,6 +227,7 @@ Status SubgraphExecutor::PrepareNodes(int group) { if (node_item.is_dynamic) { auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { GetContext().SetSessionId(context_->session_id); + GetContext().SetContextId(context_->context_id); GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); return PrepareForExecution(context_, *p_node_state); }); @@ -273,10 +274,8 @@ Status SubgraphExecutor::PrepareNodes(int group) { } Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) const { - GetContext().SetSessionId(context_->context_id); HYBRID_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state), "[%s] Failed to InferShape.", node_state.GetName().c_str()); - GetContext().SetSessionId(context_->session_id); HYBRID_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_state), "[%s] Failed to PropagateOutputShapes.", node_state.GetName().c_str()); return SUCCESS; @@ -345,6 +344,7 @@ Status SubgraphExecutor::ScheduleTasks(int group) { GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); auto prepare_future = std::async(std::launch::async, [&]() -> Status { GetContext().SetSessionId(context_->session_id); + GetContext().SetContextId(context_->context_id); auto ret = PrepareNodes(group); ready_queue_.Push(nullptr); return ret; diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 07c2ddb5..6af2fd4a 100644 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -307,11 +307,9 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { auto execution_context = context.GetExecutionContext(); - GetContext().SetSessionId(execution_context->context_id); RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start"); GE_CHK_STATUS_RET(CalcTilingInfo(node, tiling_info)); RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End"); - GetContext().SetSessionId(execution_context->session_id); // update op args by tiling info block_dim_ = static_cast(tiling_info.block_dim);