/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ #define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ #include #include #include #include "ge/ge_api_error_codes.h" #include "graph/compute_graph.h" #include "graph/manager/graph_var_manager.h" #include "model/ge_model.h" namespace ge { using Json = nlohmann::json; struct CacheInfo { size_t node_num; size_t edge_num; size_t graph_hash; map nodes_hash; CacheInfo() : node_num(0), edge_num(0), graph_hash(0) {} }; class ModelCacheHelper { public: ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); ~ModelCacheHelper(); Status SaveCacheInfoToCache () const; Status SaveVarManagerToCache(bool before_build) const; Status SaveOmModelToCache(const GeModelPtr &ge_model) const; bool IsModelCacheHit() const; Status RecoverVarManagerFromCache() const; Status LoadOmModelFromCache(GeModelPtr &ge_model) const; Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); Status ClearCache(uint32_t graph_id) const; private: Status GetComputeGraphHash(size_t &hash) const; Status GetNodesHash(map &hash_map) const; Status GetCacheInfo(CacheInfo &cache_info) const; Status RecoverMemResource(const Json &json) const; Status RecoverAllocatedGraphId(const Json &json) const; Status RecoverChangedGraphId(const Json &json) const; Status RecoverVarAddrAndTensorDesc(const Json &json) const; Status RecoverBroadcastInfo(const Json &json) const; Status RecoverTransRoads(const Json &json) const; static Status GetNodesNeedRecompile(ComputeGraphPtr &graph, vector &nodes); static Status RecompileNodes(GeModelPtr &ge_model); bool IsNodeHashSameAsCache(const map &hash_map) const; bool IsMemResourceSameAsCache(Json &json) const; bool IsChangedGraphIdSameAsCache(Json &json) const; bool IsAllocatedGraphIdSameAsCache(Json &json) const; bool IsCurVarTensorDescSameAsCache(Json &json) const; bool IsVarAddrMgrMapSameAsCache(Json &json) const; bool IsBroadcastInfoSameAsCache(Json &json) const; bool IsTransRoadsSameAsCache(Json &json) const; bool IsVarManagerSameAsCache(Json &json) const; bool IsVarManagerParamSameAsCache(Json &json) const; Status SaveJsonToFile(const string &file_name, const Json &json) const; Status LoadJsonFromFile(const string &file_name, Json &json) const; Status GetNodesHashMapJson(Json &json) const; Status GetMemResourceMap(Json &json) const; Status GetVarAddrMgrMapJson(Json &json) const; Status GetCurVarTensorDescMapJson(Json &json) const; Status GetTransRoadsJson(Json &json) const; Status GetChangedGraphIdJson(Json &json) const; Status GetAllocatedGraphIdJson(Json &json) const; Status GetBroadcastInfoJson(Json &json) const; Status GetVarResourceJson(Json &json) const; Status GetVarManagerJson(Json &json) const; static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); static Status ParseMemResourceFromJson(const Json &json, map &mem_resource); static Status ParseVarAddrMgrMapFromJson(const Json &json, std::vector> &var_addr_mgr_vector, std::set &var_offset_set); static Status ParseCurVarTensorDescMapFromJson( const Json &json, std::unordered_map &cur_var_tensor_desc_map); static Status ParseTransRoadsFromJson(const Json &json, std::unordered_map> &trans_roads); static Status ParseChangedGraphIdFromJson(const Json &json, std::map &changed_graph_id); static Status ParseAllocatedGraphIdFromJson(const Json &json, std::map &allocated_graph_id); static Status ParseBroadcastInfoFromJson(const Json &json, std::unordered_map &var_broadcast_info); static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); uint64_t session_id_; uint32_t graph_id_; string cache_path_; ComputeGraphPtr compute_graph_; std::set var_names_; bool is_cache_path_valid_for_output; static map graph_id_run_times_; }; using ModelCacheHelperPtr = std::shared_ptr; } // namespace ge #endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_