You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.0 KiB
102 lines
3.0 KiB
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "graph/passes/assert_pass.h"
|
|
|
|
#include <map>
|
|
#include <queue>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "framework/common/debug/ge_log.h"
|
|
#include "framework/common/ge_inner_error_codes.h"
|
|
#include "framework/common/util.h"
|
|
|
|
namespace ge {
|
|
// aicpu not support string type, so current implemention is Upward traversal
|
|
Status AssertPass::Run(NodePtr &node) {
|
|
GELOGD("AssertPass running");
|
|
if (node == nullptr) {
|
|
GELOGE(PARAM_INVALID, "param [node] must not be null.");
|
|
return PARAM_INVALID;
|
|
}
|
|
if (node->GetOpDesc() == nullptr) {
|
|
GELOGE(PARAM_INVALID, "param [node] [opDesc] must not be null.");
|
|
return PARAM_INVALID;
|
|
}
|
|
std::string op_type = node->GetOpDesc()->GetType();
|
|
if (op_type == ASSERT) {
|
|
GELOGD("op type is assert.");
|
|
|
|
std::vector<NodePtr> nodes_unused;
|
|
// collect assert and other unused ops
|
|
CollectUnusedNode(node, nodes_unused);
|
|
// remove unused node
|
|
Status status = RemoveUnusedNode(nodes_unused);
|
|
if (status != SUCCESS) {
|
|
GELOGE(status, "remove unused node failed.");
|
|
return status;
|
|
}
|
|
}
|
|
return SUCCESS;
|
|
}
|
|
|
|
void AssertPass::CollectUnusedNode(const NodePtr &assert_node, vector<NodePtr> &nodes_unused) {
|
|
std::map<Node *, uint32_t> invalid_outdata_info;
|
|
std::queue<NodePtr> node_queue;
|
|
node_queue.push(assert_node);
|
|
|
|
while (!node_queue.empty()) {
|
|
NodePtr cur_node = node_queue.front();
|
|
if (cur_node == nullptr) {
|
|
continue;
|
|
}
|
|
node_queue.pop();
|
|
nodes_unused.push_back(cur_node);
|
|
|
|
for (const auto &src_node : cur_node->GetInDataNodes()) {
|
|
if (src_node != nullptr && src_node->GetOpDesc() != nullptr) {
|
|
auto size = ++invalid_outdata_info[src_node.get()];
|
|
// src_node need to be deleted
|
|
if (src_node->GetOutDataNodesSize() == size && src_node->GetOpDesc()->GetType() != DATA &&
|
|
src_node->GetOpDesc()->GetType() != AIPPDATA) {
|
|
node_queue.push(src_node);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Status AssertPass::RemoveUnusedNode(std::vector<NodePtr> &nodes_unused) {
|
|
for (NodePtr &node : nodes_unused) {
|
|
if (node == nullptr) {
|
|
continue;
|
|
}
|
|
std::vector<int> assert_io_map;
|
|
size_t out_nums = node->GetAllOutDataAnchorsSize();
|
|
while (out_nums > 0) {
|
|
assert_io_map.push_back(-1);
|
|
out_nums--;
|
|
}
|
|
|
|
if (IsolateAndDeleteNode(node, assert_io_map) != SUCCESS) {
|
|
return FAILED;
|
|
}
|
|
}
|
|
return SUCCESS;
|
|
}
|
|
} // namespace ge
|