|
|
|
/**
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include "graph/passes/switch_logic_remove_pass.h"
|
|
|
|
#include <string>
|
|
|
|
#include <vector>
|
|
|
|
#include <utility>
|
|
|
|
#include "framework/common/debug/ge_log.h"
|
|
|
|
#include "graph/utils/graph_utils.h"
|
|
|
|
#include "graph/passes/pass_utils.h"
|
|
|
|
#include "common/util.h"
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
namespace {
|
|
|
|
// Identifies the producer of a switch's pred input: GetPredNode stores the
// producing node in `first` and the index of its output anchor in `second`.
using PredNodeAndOut = std::pair<NodePtr, int>;
|
|
|
|
// Number of output data anchors this pass visits on a switch node.
// GetOutputNameFromIndex labels index 0 "false" and index 1 "true".
constexpr int kSwitchOutputNum = 2;
|
|
|
|
// Index of the switch input data anchor that GetPredNode follows to locate
// the pred node.
constexpr int kSwitchPredIndex = 1;
|
|
|
|
|
|
|
|
char const *GetOutputNameFromIndex(int index) {
|
|
|
|
if ((index >= 0) && (index < kSwitchOutputNum)) {
|
|
|
|
static char const *name[kSwitchOutputNum] = {"false", "true"};
|
|
|
|
return name[index];
|
|
|
|
}
|
|
|
|
return "UNKNOWN";
|
|
|
|
}
|
|
|
|
|
|
|
|
inline bool IsSwitch(const std::string &type) {
|
|
|
|
return type == SWITCH || type == REFSWITCH;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Finds the node that feeds the pred input of `switch_node`.
//
// Follows input data anchor kSwitchPredIndex of the switch to its peer output
// anchor. On success, pred_node_index.first receives that anchor's owner node
// and pred_node_index.second receives the anchor's index. pred_node_index is
// left untouched on every failure path.
//
// Returns SUCCESS when the pred node was found, INTERNAL_ERROR when the pred
// input anchor, its peer output anchor, or the owner node is missing. A null
// `switch_node` is rejected by GE_CHECK_NOTNULL (return code defined by that
// macro).
Status GetPredNode(const NodePtr &switch_node, PredNodeAndOut &pred_node_index) {
  GE_CHECK_NOTNULL(switch_node);

  auto pred_in_anchor = switch_node->GetInDataAnchor(kSwitchPredIndex);
  if (pred_in_anchor == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to get pred node for switch %s, no pred anchor", switch_node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  auto pred_node_anchor = pred_in_anchor->GetPeerOutAnchor();
  if (pred_node_anchor == nullptr) {
    // The pred input exists but nothing is connected to it.
    GELOGE(INTERNAL_ERROR,
           "Failed to get pred node for switch %s, no peer out anchor",
           switch_node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  auto pred_node = pred_node_anchor->GetOwnerNode();
  if (pred_node == nullptr) {
    GELOGE(INTERNAL_ERROR,
           "Failed to get pred node for switch %s, null node",
           switch_node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  // Only publish the result once every lookup has succeeded.
  pred_node_index.first = pred_node;
  pred_node_index.second = pred_node_anchor->GetIdx();
  return SUCCESS;
}
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
// Pass entry point.
//
// Nodes that are not Switch/RefSwitch are left alone and SUCCESS is returned.
// For a switch, every consumer attached to each of its kSwitchOutputNum data
// outputs is inspected. A consumer that is itself a switch and whose pred
// input comes from the same (node, output index) pair as this switch's pred
// is passed to RemoveSwitchNodeLogically, together with the index of the
// output of `node` it is attached to.
//
// Returns INTERNAL_ERROR when a pred node cannot be resolved or a peer anchor
// or its owner node is null, and forwards any failure from
// RemoveSwitchNodeLogically. A missing output anchor on `node` only logs a
// warning and ends the pass with SUCCESS.
Status SwitchLogicRemovePass::Run(NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  if (!IsSwitch(node->GetType())) {
    return SUCCESS;
  }

  PredNodeAndOut own_pred;
  if (GetPredNode(node, own_pred) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to run switch logic remove pass, no pred node found from switch %s",
           node->GetName().c_str());
    return INTERNAL_ERROR;
  }

  for (int branch = 0; branch < kSwitchOutputNum; ++branch) {
    auto branch_anchor = node->GetOutDataAnchor(branch);
    if (branch_anchor == nullptr) {
      GELOGW("Unexpected switch node, the %d out anchor is null", branch);
      return SUCCESS;
    }

    for (auto &peer_in_anchor : branch_anchor->GetPeerInDataAnchors()) {
      if (peer_in_anchor == nullptr) {
        GELOGE(INTERNAL_ERROR, "The in-anchor from out anchor %d node %s is null", branch, node->GetName().c_str());
        return INTERNAL_ERROR;
      }

      auto consumer = peer_in_anchor->GetOwnerNode();
      if (consumer == nullptr) {
        GELOGE(INTERNAL_ERROR, "The peer node from out anchor %d node %s is null", branch, node->GetName().c_str());
        return INTERNAL_ERROR;
      }

      // Only cascaded switches are candidates for removal.
      if (!IsSwitch(consumer->GetType())) {
        continue;
      }

      PredNodeAndOut consumer_pred;
      if (GetPredNode(consumer, consumer_pred) != SUCCESS) {
        GELOGE(INTERNAL_ERROR, "Failed to run switch logic remove pass, no pred node found from switch %s",
               consumer->GetName().c_str());
        return INTERNAL_ERROR;
      }

      // Both switches must be driven by the very same output of the same node.
      const bool shares_pred = (own_pred == consumer_pred);
      if (!shares_pred) {
        continue;
      }

      GELOGI("The switch nodes cascaded %s and %s have the save pred node %s, the %s can be remove",
             node->GetName().c_str(), consumer->GetName().c_str(),
             own_pred.first->GetName().c_str(), consumer->GetName().c_str());

      const Status remove_status = RemoveSwitchNodeLogically(branch, consumer);
      if (remove_status != SUCCESS) {
        return remove_status;
      }
    }
  }

  return SUCCESS;
}
|
|
|
|
|
|
|
|
// Removes the cascaded switch `switch_node`, keeping only the branch selected
// by `parent_index`.
//
// For every output index other than `parent_index`, the branch hanging off
// that output is removed through PassUtils::RemoveInactiveBranchToMerge. The
// nodes it reports as deleted are registered with AddNodeDeleted, and the end
// nodes it reports are queued with AddRePassNode. Finally the switch itself
// is handed to IsolateAndDeleteNode.
//
// parent_index: output index of the parent switch that `switch_node` is
//               attached to; the output of `switch_node` with the same index
//               is the one that is kept.
// switch_node:  the cascaded switch to remove; must not be null.
//
// Returns the first non-SUCCESS status from RemoveInactiveBranchToMerge,
// otherwise the result of IsolateAndDeleteNode. Null nodes are rejected by
// GE_CHECK_NOTNULL (return code defined by that macro).
Status SwitchLogicRemovePass::RemoveSwitchNodeLogically(int parent_index, NodePtr &switch_node) {
  // Validate once up front: every path below, including the final log and the
  // IsolateAndDeleteNode call, dereferences switch_node.
  GE_CHECK_NOTNULL(switch_node);

  // One entry per switch output, sized from the same constant the loop uses.
  // All entries start at -1; the kept output is set to 0.
  // NOTE(review): presumably 0 asks IsolateAndDeleteNode to rewire that output
  // to input 0 and -1 means "do not rewire" - confirm against that helper.
  std::vector<int> isolate_map(static_cast<size_t>(kSwitchOutputNum), -1);

  for (int i = 0; i < kSwitchOutputNum; ++i) {
    if (i == parent_index) {
      isolate_map[i] = 0;
      continue;
    }

    auto out_anchor = switch_node->GetOutDataAnchor(i);
    if (out_anchor == nullptr) {
      GELOGW("The switch removing %s does not has %d out anchor, ignore it", switch_node->GetName().c_str(), i);
      continue;
    }

    GELOGI("Remove inactivate branch %s(%d) from switch %s",
           GetOutputNameFromIndex(i), i, switch_node->GetName().c_str());
    std::vector<NodePtr> deleted_nodes;
    std::vector<NodePtr> end_nodes;
    auto ret = PassUtils::RemoveInactiveBranchToMerge(out_anchor, deleted_nodes, end_nodes);
    if (ret != SUCCESS) {
      return ret;
    }

    for (auto &node : deleted_nodes) {
      GE_CHECK_NOTNULL(node);
      GELOGD("Remove node %s from inactivate branch from switch %s",
             node->GetName().c_str(), switch_node->GetName().c_str());
      AddNodeDeleted(node);
    }
    for (auto &node : end_nodes) {
      GE_CHECK_NOTNULL(node);
      GELOGD("Add end node %s to re-pass list, for inactivate branch from switch %s",
             node->GetName().c_str(), switch_node->GetName().c_str());
      AddRePassNode(node);
    }
  }

  GELOGI("Remove switch node cascaded %s, replace out index %d",
         switch_node->GetName().c_str(), parent_index);
  return IsolateAndDeleteNode(switch_node, isolate_map);
}
|
|
|
|
} // namespace ge
|
|
|
|
|