You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
200 lines
6.5 KiB
200 lines
6.5 KiB
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef GE_SESSION_SESSION_MANAGER_H_
|
|
#define GE_SESSION_SESSION_MANAGER_H_
|
|
|
|
#include <iostream>
|
|
#include <map>
|
|
#include <memory>
|
|
#include <string>
|
|
#include <vector>
|
|
#include "common/ge_inner_error_codes.h"
|
|
#include "ge/ge_api_types.h"
|
|
#include "session/inner_session.h"
|
|
|
|
namespace ge {
|
|
using SessionPtr = std::shared_ptr<InnerSession>;
|
|
|
|
class SessionManager {
|
|
friend class GELib;
|
|
|
|
public:
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief create session
|
|
/// @param [in] options session config options
|
|
/// @param [out] session_id session id
|
|
/// @return Status result of function
|
|
///
|
|
Status CreateSession(const std::map<std::string, std::string> &options, SessionId &session_id);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief destroy the session with specific session id
|
|
/// @param [in] session_id session id
|
|
/// @return Status result of function
|
|
///
|
|
Status DestroySession(SessionId session_id);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief add a graph to the session with specific session id
|
|
/// @param [in] session_id session id
|
|
/// @param [in] graph_id graph id
|
|
/// @param [in] graph the graph to add
|
|
/// @return Status result of function
|
|
///
|
|
Status AddGraph(SessionId session_id, uint32_t graph_id, const ge::Graph &graph);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief add a graph to the session with specific session id and graphOptions
|
|
/// @param [in] session_id session id
|
|
/// @param [in] graph_id graph id
|
|
/// @param [in] graph the graph to add
|
|
/// @param [in] options graph level options
|
|
/// @return Status result of function
|
|
///
|
|
Status AddGraph(SessionId session_id, uint32_t graph_id, const Graph &graph,
|
|
const std::map<std::string, std::string> &options);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief add a copy graph to the session with specific session id and graphOptions
|
|
/// @param [in] session_id session id
|
|
/// @param [in] graph_id graph id
|
|
/// @param [in] graph the graph to add
|
|
/// @param [in] options graph level options
|
|
/// @return Status result of function
|
|
///
|
|
Status AddGraphWithCopy(SessionId session_id, uint32_t graph_id, const Graph &graph,
|
|
const std::map<std::string, std::string> &options);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief run a graph of the session with specific session id
|
|
/// @param [in] session_id session id
|
|
/// @param [in] graph_id graph id
|
|
/// @param [in] inputs input data
|
|
/// @param [out] outputs output data
|
|
/// @return Status result of function
|
|
///
|
|
Status RunGraph(SessionId session_id, uint32_t graph_id, const std::vector<Tensor> &inputs,
|
|
std::vector<Tensor> &outputs);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief remove a graph from the session with specific session id
|
|
/// @param [in] session_id session id
|
|
/// @param [in] graph_id graph id
|
|
/// @return Status result of function
|
|
///
|
|
Status RemoveGraph(SessionId session_id, uint32_t graph_id);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief get variable value from the session with specific session id
|
|
/// @param [in] session_id session id
|
|
/// @param [in] name op name
|
|
/// @param [out] val out value tensor
|
|
/// @return Status result of function
|
|
///
|
|
Status GetVariable(SessionId session_id, const std::string &name, Tensor &val);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief build a graph of the session with specific session id
|
|
/// @param [in] session_id session id
|
|
/// @param [in] graph_id graph id
|
|
/// @param [in] inputs input data
|
|
/// @return Status result of function
|
|
///
|
|
Status BuildGraph(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief run a graph of the session with specific session id for train asynchronously
|
|
/// @param [in] session_id session id
|
|
/// @param [in] graph_id graph id
|
|
/// @param [in] inputs input data
|
|
/// @return Status result of function
|
|
///
|
|
Status RunGraphAsync(SessionId session_id, uint32_t graph_id, const std::vector<InputTensorInfo> &inputs,
|
|
RunAsyncCallback callback);
|
|
|
|
///
|
|
/// @ingroup ge_graph
|
|
/// @brief get variables in the session with specific session id
|
|
/// @param [in] session_id: sssion id
|
|
/// @param [in] var_names: variable names
|
|
/// @param [out] var_values: variable values
|
|
/// @return Status result of function
|
|
///
|
|
Status GetVariables(SessionId session_id, const std::vector<std::string> &var_names,
|
|
std::vector<Tensor> &var_values);
|
|
|
|
///
|
|
/// @ingroup ge_graph
|
|
/// @brief me register the callback function to get the result of summary or checkpoin
|
|
/// @param [in] session_id session id
|
|
/// @param [in] key: summary or checkpoint
|
|
/// @param [in] callbak: The real callback object of me
|
|
/// @return Status result of function
|
|
///
|
|
Status RegisterCallBackFunc(
|
|
SessionId session_id, const std::string &key,
|
|
const std::function<Status(uint32_t, const std::map<std::string, ge::Tensor> &)> &callback);
|
|
Status RegisterCallBackFunc(
|
|
SessionId session_id, const std::string &key,
|
|
const std::function<Status(uint32_t, const std::map<AscendString, ge::Tensor> &)> &callback);
|
|
|
|
bool IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id);
|
|
|
|
private:
|
|
SessionManager() = default;
|
|
|
|
~SessionManager() = default;
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief initialize session manager
|
|
/// @param [in] options session manager config options
|
|
/// @return Status result of function
|
|
///
|
|
Status Initialize(const std::map<std::string, std::string> &options);
|
|
|
|
///
|
|
/// @ingroup ge_session
|
|
/// @brief finalize session manager
|
|
/// @return Status result of function
|
|
///
|
|
Status Finalize();
|
|
|
|
bool HasSession(SessionId session_id);
|
|
|
|
Status GetNextSessionId(SessionId &next_session_id);
|
|
|
|
Status SetRtContext(SessionId session_id, rtContext_t rtContext);
|
|
|
|
std::map<SessionId, SessionPtr> session_manager_map_;
|
|
std::mutex mutex_;
|
|
bool init_flag_ = false;
|
|
};
|
|
} // namespace ge
|
|
|
|
#endif // GE_SESSION_SESSION_MANAGER_H_
|