From dc0edfc5b82150ace42fb87d08e0e1c7f5f3f9ee Mon Sep 17 00:00:00 2001 From: wuweikang Date: Thu, 5 Nov 2020 15:33:26 +0800 Subject: [PATCH] AddGraphWithCopy --- ge/client/ge_api.cc | 27 ++++++++++++ ge/graph/manager/graph_manager.cc | 72 +++++++++++++++++++++++++++++++ ge/graph/manager/graph_manager.h | 10 +++++ ge/session/inner_session.cc | 18 ++++++++ ge/session/inner_session.h | 2 + ge/session/session_manager.cc | 30 +++++++++++++ ge/session/session_manager.h | 14 +++++- inc/external/ge/ge_api.h | 19 ++++++++ 8 files changed, 191 insertions(+), 1 deletion(-) diff --git a/ge/client/ge_api.cc b/ge/client/ge_api.cc index 522985fa..5619f137 100644 --- a/ge/client/ge_api.cc +++ b/ge/client/ge_api.cc @@ -260,6 +260,33 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map options; + return AddGraphWithCopy(graph_id, graph, options); +} + +Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph, + const std::map &options) { + GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_); + std::shared_ptr instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session."); + return FAILED; + } + std::map str_options; + for (auto it = options.begin(); it != options.end(); ++it) { + str_options.insert({it->first.GetString(), it->second.GetString()}); + } + GELOGD("Adding graph to session"); + Status ret = instance_ptr->SessionManagerObj().AddGraphWithCopy(sessionId_, graph_id, graph, str_options); + if (ret != SUCCESS) { + GELOGE(ret, "AddGraph failed in Session."); + return FAILED; + } + GELOGD("AddGraph finished in Session."); + return ret; +} + Status Session::RemoveGraph(uint32_t graph_id) { GELOGT(TRACE_INIT, "Session RemoveGraph start"); diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 4737955d..ec05e598 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -330,6 +330,78 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, return SUCCESS; } +Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph, + const std::map &options, + const OmgContext &omg_context) { + if (HasGraphNode(graph_id)) { + GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id); + return GE_GRAPH_GRAPH_ALREADY_EXIST; + } + auto compute_graph = GraphUtils::GetComputeGraph(graph); + if (compute_graph != nullptr) { + compute_graph->SetGraphID(graph_id); + bool graph_has_been_added = false; + if (AttrUtils::GetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, graph_has_been_added) + && graph_has_been_added) { + GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, + "[GraphManager] same graph object can not be added again, graph_id = %u.", graph_id); + return GE_GRAPH_GRAPH_ALREADY_EXIST; + } + } else { + GELOGE(FAILED, "compute graph is null"); + return FAILED; + } + std::vector input_nodes; + std::vector output_nodes; + auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); + std::string session_graph_id; + if (!AttrUtils::GetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || + session_graph_id.empty()) { + session_graph_id = "-1_" + to_string(graph_id); + if (!AttrUtils::SetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { + GELOGW("Set attribute of compute graph failed."); + } + for (auto &subgraph : new_compute_graph->GetAllSubgraphs()) { + (void)AttrUtils::SetStr(*subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); + } + GELOGW("Get graph session_graph_id attr failed, set session id to default value: [0]"); + } + + GraphNodePtr graph_node = MakeShared(graph_id); + if (graph_node == nullptr) { + GELOGE(FAILED, "GraphNode make shared failed"); + return FAILED; + } + std::shared_ptr graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph); + if (graph_ptr == nullptr) { + GELOGE(FAILED, "GraphPtr make shared failed"); + return FAILED; + } + + graph_node->SetGraph(graph_ptr); + graph_node->SetOptions(options); + AddGraphNode(graph_id, graph_node); + + AddLocalOmgContext(graph_id, omg_context); + if (!options_.output_datatype.empty()) { + GetLocalOmgContext().output_type = options_.output_datatype; + } + + CompilerStages &stages = GetCompilerStages(graph_id); + stages.preparer.SetOptions(options_); + Status status = stages.optimizer.SetOptions(options_); + if (status != SUCCESS) { + GELOGE(status, "Graph optimizer set options failed."); + return status; + } + stages.builder.SetOptions(options_); + + var_acc_ctrl_.AddGraph(graph_id, new_compute_graph); + + GELOGI("[GraphManager] add graph success, graph_id = %u.", graph_id); + return SUCCESS; +} + Status GraphManager::MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph, GraphId root_graph_id) { std::shared_ptr instance_ptr = ge::GELib::GetInstance(); diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index fc3601af..d66f1ce4 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -73,6 +73,16 @@ class GraphManager { Status AddGraph(const GraphId &graph_id, const Graph &graph, const std::map &options, const OmgContext &omg_context); + /// + /// @ingroup ge_graph + /// @brief add a copy graph + /// @param [in] graph_id graph id + /// @param [out] Graph output graph + /// @return Status result of function + /// + Status AddGraphWithCopy(const GraphId &graph_id, const Graph &graph, + const std::map &options, const OmgContext &omg_context); + /// /// @ingroup ge_graph /// @brief remove specific graph diff --git a/ge/session/inner_session.cc b/ge/session/inner_session.cc index aa825a4b..ec85d9ac 100755 --- a/ge/session/inner_session.cc +++ b/ge/session/inner_session.cc @@ -166,6 +166,24 @@ Status InnerSession::AddGraph(uint32_t graph_id, const Graph &graph, return SUCCESS; } +Status InnerSession::AddGraphWithCopy(uint32_t graph_id, const Graph &graph, + const std::map &options) { + std::lock_guard lock(resource_mutex_); + if (!init_flag_) { + GELOGE(GE_SESS_INIT_FAILED, "[InnerSession:%lu] initialize failed.", session_id_); + return GE_SESS_INIT_FAILED; + } + UpdateThreadContext(options); + Status ret = graph_manager_.AddGraphWithCopy(graph_id, graph, options, domi::GetContext()); + if (ret != SUCCESS) { + GELOGE(ret, "[InnerSession:%lu] add graph %u failed.", session_id_, graph_id); + return ret; + } + + GELOGI("[InnerSession:%lu] add graph success, graph_id=%u.", session_id_, graph_id); + return SUCCESS; +} + Status InnerSession::RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector &outputs) { GELOGI("[InnerSession:%lu] run graph on session, graph_id=%u.", session_id_, graph_id); if (mutex_.try_lock()) { diff --git a/ge/session/inner_session.h b/ge/session/inner_session.h index 25f5c307..db7a2c92 100644 --- a/ge/session/inner_session.h +++ b/ge/session/inner_session.h @@ -37,6 +37,8 @@ class InnerSession { Status AddGraph(uint32_t graph_id, const Graph &graph, const std::map &options); + Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map &options); + Status RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector &outputs); Status RemoveGraph(uint32_t graph_id); diff --git a/ge/session/session_manager.cc b/ge/session/session_manager.cc index 6f8c9432..81efb080 100755 --- a/ge/session/session_manager.cc +++ b/ge/session/session_manager.cc @@ -170,6 +170,36 @@ Status SessionManager::AddGraph(SessionId session_id, uint32_t graph_id, const G return innerSession->AddGraph(graph_id, graph, options); } +Status SessionManager::AddGraphWithCopy(SessionId session_id, uint32_t graph_id, const Graph &graph, + const std::map &options) { + if (!init_flag_) { + GELOGE(GE_SESSION_MANAGER_NOT_INIT); + return GE_SESSION_MANAGER_NOT_INIT; + } + SessionPtr innerSession = nullptr; + { + std::lock_guard lock(mutex_); + std::map::iterator it = session_manager_map_.find(session_id); + if (it == session_manager_map_.end()) { + return GE_SESSION_NOT_EXIST; + } else { + innerSession = it->second; + } + auto compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + std::string session_graph_id = std::to_string(session_id) + "_" + std::to_string(graph_id); + if (!AttrUtils::SetStr(*compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { + GELOGW("Set graph session_graph_id attr failed."); + } else { + GELOGD("Set graph session_graph_id attr to [%s]", session_graph_id.c_str()); + } + for (auto graph : compute_graph->GetAllSubgraphs()) { + AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); + } + } + return innerSession->AddGraphWithCopy(graph_id, graph, options); +} + Status SessionManager::RunGraph(SessionId session_id, uint32_t graph_id, const std::vector &inputs, std::vector &outputs) { if (!init_flag_) { diff --git a/ge/session/session_manager.h b/ge/session/session_manager.h index 88864f61..ac901c3a 100644 --- a/ge/session/session_manager.h +++ b/ge/session/session_manager.h @@ -62,7 +62,7 @@ class SessionManager { /// /// @ingroup ge_session - /// @brief add a graph to the session with specific session id + /// @brief add a graph to the session with specific session id and graphOptions /// @param [in] session_id session id /// @param [in] graph_id graph id /// @param [in] graph the graph to add @@ -72,6 +72,18 @@ class SessionManager { Status AddGraph(SessionId session_id, uint32_t graph_id, const Graph &graph, const std::map &options); + /// + /// @ingroup ge_session + /// @brief add a copy graph to the session with specific session id and graphOptions + /// @param [in] session_id session id + /// @param [in] graph_id graph id + /// @param [in] graph the graph to add + /// @param [in] options graph level options + /// @return Status result of function + /// + Status AddGraphWithCopy(SessionId session_id, uint32_t graph_id, const Graph &graph, + const std::map &options); + /// /// @ingroup ge_session /// @brief run a graph of the session with specific session id diff --git a/inc/external/ge/ge_api.h b/inc/external/ge/ge_api.h index b4b9bb2a..8fd4b944 100644 --- a/inc/external/ge/ge_api.h +++ b/inc/external/ge/ge_api.h @@ -59,6 +59,25 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { /// Status AddGraph(uint32_t graphId, const Graph &graph, const std::map &options); + /// + /// @ingroup client + /// @brief add a copy graph with a specific graphId + /// @param [in] graphId graph id + /// @param [in] graph the graph + /// @return Status result of function + /// + Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph); + + /// + /// @ingroup client + /// @brief add a copy graph with a specific graphId and graphOptions + /// @param [in] graphId graph id + /// @param [in] graph the graph + /// @param [in] options graph options + /// @return Status result of function + /// + Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map &options); + /// /// @ingroup ge_graph /// @brief remove a graph of the session with specific session id