/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_EXTERNAL_GE_GE_API_H_ #define INC_EXTERNAL_GE_GE_API_H_ #include #include #include #include "ge/ge_api_error_codes.h" #include "ge/ge_api_types.h" #include "graph/graph.h" #include "graph/tensor.h" namespace ge { typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map ¶ms_list); namespace session { typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map ¶ms_list); } // Initialize GE ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY Status GEInitialize(const std::map &)) GE_FUNC_VISIBILITY Status GEInitialize(const std::map &options); GE_FUNC_VISIBILITY Status GEInitialize(const std::map &options); // Finalize GE, release all resources GE_FUNC_VISIBILITY Status GEFinalize(); GE_FUNC_VISIBILITY std::string GEGetErrorMsg(); GE_FUNC_VISIBILITY std::string GEGetWarningMsg(); class GE_FUNC_VISIBILITY Session { public: ATTRIBUTED_DEPRECATED(Session(const std::map &)) explicit Session(const std::map &options); explicit Session(const std::map &options); ~Session(); /// /// @ingroup client /// @brief add a graph with a specific graphId /// @param [in] graphId graph id /// @return Status result of function /// Status AddGraph(uint32_t graphId, const Graph &graph); /// /// @ingroup client /// @brief add a graph with a specific graphId and graphOptions /// @param [in] graphId graph id /// @param [in] graph the graph /// @param [in] options graph options /// @return Status result of function /// ATTRIBUTED_DEPRECATED(Status AddGraph(uint32_t, const Graph &, const std::map &)) Status AddGraph(uint32_t graphId, const Graph &graph, const std::map &options); /// /// @ingroup client /// @brief add a graph with a specific graphId and graphOptions /// @param [in] graphId graph id /// @param [in] graph the graph /// @param [in] options graph options /// @return Status result of function /// Status AddGraph(uint32_t graphId, const Graph &graph, const std::map &options); /// /// @ingroup client /// @brief add a copy graph with a specific graphId /// @param [in] graphId graph id /// @param [in] graph the graph /// @return Status result of function /// Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph); /// /// @ingroup client /// @brief add a copy graph with a specific graphId and graphOptions /// @param [in] graphId graph id /// @param [in] graph the graph /// @param [in] options graph options /// @return Status result of function /// Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map &options); /// /// @ingroup ge_graph /// @brief remove a graph of the session with specific session id /// @param [in] graphId graph id /// @return Status result of function /// Status RemoveGraph(uint32_t graphId); /// /// @ingroup ge_graph /// @brief run a graph of the session with specific session id /// @param [in] graphId graph id /// @param [in] inputs input data /// @param [out] outputs output data /// @return Status result of function /// Status RunGraph(uint32_t graphId, const std::vector &inputs, std::vector &outputs); /// /// @ingroup ge_graph /// @brief build graph in the session with specific session id /// @param [in] graphId: graph id /// @param [in] inputs: input data /// @return Status result of function /// Status BuildGraph(uint32_t graphId, const std::vector &inputs); /// /// @ingroup ge_graph /// @brief run graph in the session with specific session id asynchronously /// @param [in] graphId: graph id /// @param [in] inputs: input data /// @param [out] callback: callback while runing graph has been finished. /// The callback function will not be checked. /// Please ensure that the implementation of the function is trusted. /// @return Status result of function /// Status RunGraphAsync(uint32_t graphId, const std::vector &inputs, RunAsyncCallback callback); /// /// @ingroup ge_graph /// @brief get variables in the session with specific session id /// @param [in] var_names: variable names /// @param [out] var_values: variable values /// @return Status result of function /// ATTRIBUTED_DEPRECATED(Status GetVariables(const std::vector &, std::vector &)) Status GetVariables(const std::vector &var_names, std::vector &var_values); /// /// @ingroup ge_graph /// @brief get variables in the session with specific session id /// @param [in] var_names: variable names /// @param [out] var_values: variable values /// @return Status result of function /// Status GetVariables(const std::vector &var_names, std::vector &var_values); /// /// @ingroup ge_graph /// @brief register callback func with specific summary or checkpoint by users /// @param [in] key: func key /// @param [in] callback: callback specific summary or checkpoint. /// The callback function will not be checked. /// Please ensure that the implementation of the function is trusted. /// @return Status result of function /// ATTRIBUTED_DEPRECATED(Status RegisterCallBackFunc(const char *, const session::pCallBackFunc &)) Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback); Status RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback); bool IsGraphNeedRebuild(uint32_t graphId); private: uint64_t sessionId_; }; } // namespace ge #endif // INC_EXTERNAL_GE_GE_API_H_