You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
124 lines
5.2 KiB
124 lines
5.2 KiB
/**
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_
|
|
#define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_
|
|
|
|
#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>

#include "ge/ge_api_error_codes.h"
#include "graph/compute_graph.h"
#include "graph/manager/graph_var_manager.h"
#include "model/ge_model.h"
|
|
|
|
namespace ge {
|
|
using Json = nlohmann::json;
|
|
|
|
/**
 * Summary data compared when deciding whether a cached model can be reused.
 *
 * Filled through ModelCacheHelper::GetCacheInfo (declared below). The field
 * meanings are inferred from their names — NOTE(review): confirm against the
 * implementation.
 * All scalar fields start at zero and nodes_hash starts empty.
 */
struct CacheInfo {
  size_t node_num{0};    // presumably the number of nodes in the graph
  size_t edge_num{0};    // presumably the number of edges in the graph
  size_t graph_hash{0};  // presumably a hash of the whole compute graph
  // Per-node hash values keyed by node identifier (string). Qualified as
  // std::map so the header does not rely on an external `using std::map`.
  std::map<std::string, size_t> nodes_hash;

  CacheInfo() = default;
};
|
|
|
|
class ModelCacheHelper {
|
|
public:
|
|
ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph);
|
|
~ModelCacheHelper();
|
|
|
|
Status SaveCacheInfoToCache () const;
|
|
Status SaveVarManagerToCache(bool before_build) const;
|
|
Status SaveOmModelToCache(const GeModelPtr &ge_model) const;
|
|
bool IsModelCacheHit() const;
|
|
Status RecoverVarManagerFromCache() const;
|
|
Status LoadOmModelFromCache(GeModelPtr &ge_model) const;
|
|
Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph);
|
|
Status ClearCache(uint32_t graph_id) const;
|
|
|
|
private:
|
|
Status GetComputeGraphHash(size_t &hash) const;
|
|
Status GetNodesHash(map<std::string, size_t> &hash_map) const;
|
|
Status GetCacheInfo(CacheInfo &cache_info) const;
|
|
|
|
Status RecoverMemResource(const Json &json) const;
|
|
Status RecoverAllocatedGraphId(const Json &json) const;
|
|
Status RecoverChangedGraphId(const Json &json) const;
|
|
Status RecoverVarAddrAndTensorDesc(const Json &json) const;
|
|
Status RecoverBroadcastInfo(const Json &json) const;
|
|
Status RecoverTransRoads(const Json &json) const;
|
|
static Status GetNodesNeedRecompile(ComputeGraphPtr &graph, vector<NodePtr> &nodes);
|
|
static Status RecompileNodes(GeModelPtr &ge_model);
|
|
|
|
bool IsNodeHashSameAsCache(const map<std::string, size_t> &hash_map) const;
|
|
bool IsMemResourceSameAsCache(Json &json) const;
|
|
bool IsChangedGraphIdSameAsCache(Json &json) const;
|
|
bool IsAllocatedGraphIdSameAsCache(Json &json) const;
|
|
bool IsCurVarTensorDescSameAsCache(Json &json) const;
|
|
bool IsVarAddrMgrMapSameAsCache(Json &json) const;
|
|
bool IsBroadcastInfoSameAsCache(Json &json) const;
|
|
bool IsTransRoadsSameAsCache(Json &json) const;
|
|
bool IsVarManagerSameAsCache(Json &json) const;
|
|
bool IsVarManagerParamSameAsCache(Json &json) const;
|
|
|
|
Status SaveJsonToFile(const string &file_name, const Json &json) const;
|
|
Status LoadJsonFromFile(const string &file_name, Json &json) const;
|
|
|
|
Status GetNodesHashMapJson(Json &json) const;
|
|
Status GetMemResourceMap(Json &json) const;
|
|
Status GetVarAddrMgrMapJson(Json &json) const;
|
|
Status GetCurVarTensorDescMapJson(Json &json) const;
|
|
Status GetTransRoadsJson(Json &json) const;
|
|
Status GetChangedGraphIdJson(Json &json) const;
|
|
Status GetAllocatedGraphIdJson(Json &json) const;
|
|
Status GetBroadcastInfoJson(Json &json) const;
|
|
Status GetVarResourceJson(Json &json) const;
|
|
Status GetVarManagerJson(Json &json) const;
|
|
|
|
static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json);
|
|
static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc);
|
|
static Status ParseMemResourceFromJson(const Json &json, map<rtMemType_t, int64_t> &mem_resource);
|
|
static Status ParseVarAddrMgrMapFromJson(const Json &json,
|
|
std::vector<std::pair<std::string, VarAddrMgr>> &var_addr_mgr_vector,
|
|
std::set<uint64_t> &var_offset_set);
|
|
static Status ParseCurVarTensorDescMapFromJson(
|
|
const Json &json, std::unordered_map<std::string, ge::GeTensorDesc> &cur_var_tensor_desc_map);
|
|
static Status ParseTransRoadsFromJson(const Json &json,
|
|
std::unordered_map<std::string, std::vector<TransNodeInfo>> &trans_roads);
|
|
static Status ParseChangedGraphIdFromJson(const Json &json,
|
|
std::map<std::string, uint32_t> &changed_graph_id);
|
|
static Status ParseAllocatedGraphIdFromJson(const Json &json,
|
|
std::map<std::string, uint32_t> &allocated_graph_id);
|
|
static Status ParseBroadcastInfoFromJson(const Json &json,
|
|
std::unordered_map<std::string, VarBroadCastInfo> &var_broadcast_info);
|
|
static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name);
|
|
|
|
uint64_t session_id_;
|
|
uint32_t graph_id_;
|
|
string cache_path_;
|
|
ComputeGraphPtr compute_graph_;
|
|
std::set<string> var_names_;
|
|
bool is_cache_path_valid_for_output;
|
|
static map<uint32_t, uint32_t> graph_id_run_times_;
|
|
};
|
|
|
|
using ModelCacheHelperPtr = std::shared_ptr<ModelCacheHelper>;
|
|
} // namespace ge
|
|
|
|
#endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_
|