/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/save_pass.h"

#include <algorithm>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#include "framework/common/debug/ge_log.h"
#include "common/ge_inner_error_codes.h"
#include "graph/utils/graph_utils.h"

namespace ge {
namespace {
// Node type string identifying a save node.
const char *const kSave = "Save";
// Node type string identifying a variable node.
const char *const kVar = "Variable";
// Value written into kVarAttrVarIsSave on variables that feed a save node.
const char *const kVarIsSave = "save_checkpoint";
// Name of the attribute set on variables that feed a save node.
const char *const kVarAttrVarIsSave = "_var_is_save";
}  // namespace

///
/// @brief Replace every "Save" node of the graph by graph outputs on the variables it consumed.
///
/// For each direct node of type "Save", every "Variable" node connected to one of its
/// input data anchors is tagged with the attribute `_var_is_save` = "save_checkpoint"
/// and recorded, together with the index of the feeding output anchor, as an additional
/// graph output. Afterwards the "Save" nodes are removed from the graph and, if present,
/// from the graph's target node list.
///
/// @param [in] graph  graph to process; checked with GE_CHECK_NOTNULL
/// @return SUCCESS on success,
///         INTERNAL_ERROR if the attribute cannot be set on a variable,
///         the status returned by ComputeGraph::RemoveNode if a removal fails,
///         or whatever GE_CHECK_NOTNULL returns when a required pointer is null.
///
Status SavePass::Run(ge::ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);

  // (variable node, output anchor index) pairs to be appended to the graph outputs.
  std::vector<std::pair<NodePtr, int32_t>> out_nodes_info;
  // Save nodes are only collected here; they are removed after the traversal so the
  // node list is not modified while it is being iterated.
  std::vector<NodePtr> del_nodes;

  for (auto &node : graph->GetDirectNode()) {
    if (node->GetType() != kSave) {
      continue;
    }
    for (auto &in : node->GetAllInDataAnchors()) {
      auto out_anchor = in->GetPeerOutAnchor();
      if (out_anchor == nullptr) {
        continue;  // unconnected input
      }
      ge::NodePtr peer_node = out_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(peer_node);
      if (peer_node->GetType() != kVar) {
        continue;
      }
      ge::OpDescPtr op_desc = peer_node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      GE_IF_BOOL_EXEC(!ge::AttrUtils::SetStr(op_desc, kVarAttrVarIsSave, kVarIsSave),
                      GELOGE(INTERNAL_ERROR, "set kVarAttrVarIsSave failed");
                      return INTERNAL_ERROR);
      out_nodes_info.emplace_back(peer_node, out_anchor->GetIdx());
    }
    del_nodes.emplace_back(node);
  }

  // add output nodes for save
  graph->AppendGraphOutNodesInfo(out_nodes_info);

  // delete save node
  for (auto &node_ptr : del_nodes) {
    auto ret = graph->RemoveNode(node_ptr);
    if (ret != SUCCESS) {
      GELOGE(ret, "Remove save node %s from graph failed.", node_ptr->GetName().c_str());
      return ret;
    }

    // update Target list: work on a copy and write it back only when it changed
    std::vector<NodePtr> graph_target = graph->GetGraphTargetNodesInfo();
    auto iter = std::find(graph_target.begin(), graph_target.end(), node_ptr);
    if (iter != graph_target.end()) {
      GELOGI("Current node %s is as Target, remove it from target vector.", node_ptr->GetName().c_str());
      graph_target.erase(iter);
      graph->SetGraphTargetNodesInfo(graph_target);
    }
  }

  return SUCCESS;
}
}  // namespace ge