/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_SESSION_INNER_SESSION_H_ #define GE_SESSION_INNER_SESSION_H_ #include #include #include #include "framework/common/ge_types.h" #include "ge/ge_api_types.h" #include "graph/manager/graph_manager.h" namespace ge { class InnerSession { public: InnerSession(uint64_t session_id, const std::map &options); ~InnerSession() = default; Status Initialize(); Status AddGraph(uint32_t graph_id, const Graph &graph); Status AddGraph(uint32_t graph_id, const Graph &graph, const std::map &options); Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map &options); Status RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector &outputs); Status RemoveGraph(uint32_t graph_id); Status BuildGraph(uint32_t graph_id, const std::vector &inputs); Status RunGraphAsync(uint32_t graph_id, const std::vector &inputs, RunAsyncCallback callback); Status Finalize(); Status GetAllVariables(std::map &all_variables); Status GenCheckPointGraph(const std::map &all_variables, Graph &graph); Status SaveVariables(const Graph &graph, const std::vector &var_names, const std::vector &outputs, std::vector &var_values); Status GetVariable(const std::string &name, Tensor &val); Status RegisterCallBackFunc( const std::string &key, const std::function &)> &callback); Status RegisterCallBackFunc( const std::string &key, const std::function &)> &callback); const GraphManager &getGraphManagerObj() const; bool IsGraphNeedRebuild(uint32_t graph_id); Status AddDumpProperties(const DumpProperties &dump_properties); Status RemoveDumpProperties(); void SetRtSocVersion(); private: bool init_flag_; uint64_t session_id_; std::map options_; GraphManager graph_manager_; std::mutex resource_mutex_; // AddGraph, RemoveGraph and Finalize use void UpdateThreadContext(const std::map &options); void UpdateThreadContext(uint32_t graph_id); static bool is_dump_server_inited_; }; } // namespace ge #endif // GE_SESSION_INNER_SESSION_H_