parent
28052ad188
commit
89f96e347b
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,76 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "tools/converter/legacy_optimizer/graph/subgraph_node_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
void SubgraphNodePass::UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
|
||||
for (auto &subgraph : graph->subGraph) {
|
||||
for (auto &idx : subgraph->nodeIndices) {
|
||||
if (idx > node_idx) {
|
||||
idx--;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
std::vector<schema::CNodeT *> new_nodes{};
|
||||
std::transform(graph->nodes.begin(), graph->nodes.end(), std::back_inserter(new_nodes),
|
||||
[](std::unique_ptr<CNodeT> &node) { return node.get(); });
|
||||
|
||||
for (auto it = old_nodes_.begin(); it != old_nodes_.end();) {
|
||||
if (!IsContain(new_nodes, *it)) {
|
||||
size_t node_idx = it - old_nodes_.begin();
|
||||
for (auto &subgraph : graph->subGraph) {
|
||||
auto node_idx_pos = std::find(subgraph->nodeIndices.begin(), subgraph->nodeIndices.end(), node_idx);
|
||||
if (node_idx_pos != subgraph->nodeIndices.end()) {
|
||||
subgraph->nodeIndices.erase(node_idx_pos);
|
||||
UpdateSubgraphNodeIndices(node_idx, graph);
|
||||
break;
|
||||
}
|
||||
}
|
||||
it = old_nodes_.erase(it);
|
||||
} else {
|
||||
it++;
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < new_nodes.size(); i++) {
|
||||
if (!IsContain(old_nodes_, new_nodes[i])) {
|
||||
for (auto &subgraph : graph->subGraph) {
|
||||
if (IsContain(subgraph->nodeIndices, i - 1) || IsContain(subgraph->nodeIndices, i + 1)) {
|
||||
subgraph->nodeIndices.push_back(old_nodes_.size());
|
||||
old_nodes_.push_back(new_nodes[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H
|
||||
#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_NODE_PASS_H
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "tools/converter/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class SubgraphNodePass : public GraphPass {
|
||||
public:
|
||||
explicit SubgraphNodePass(std::vector<schema::CNodeT *> old_nodes) : old_nodes_(std::move(old_nodes)) {}
|
||||
|
||||
~SubgraphNodePass() override = default;
|
||||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
private:
|
||||
void UpdateSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
|
||||
std::vector<schema::CNodeT *> old_nodes_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H
|
@ -0,0 +1,100 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "tools/converter/legacy_optimizer/graph/subgraph_tensor_pass.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/common/utils.h"
|
||||
#include "tools/common/graph_util.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
|
||||
bool SubgraphTensorPass::IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
|
||||
for (const auto &node : graph->nodes) {
|
||||
if (IsContain<uint32_t>(node->inputIndex, tensor_idx)) {
|
||||
return true;
|
||||
}
|
||||
if (IsContain<uint32_t>(node->outputIndex, tensor_idx)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
STATUS SubgraphTensorPass::UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx) {
|
||||
for (const auto &subgraph : graph->subGraph) {
|
||||
UpdateVec<uint32_t>(&(subgraph->inputIndices), tensor_idx);
|
||||
UpdateVec<uint32_t>(&(subgraph->outputIndices), tensor_idx);
|
||||
}
|
||||
for (const auto &node : graph->nodes) {
|
||||
UpdateVec<uint32_t>(&(node->inputIndex), tensor_idx);
|
||||
UpdateVec<uint32_t>(&(node->outputIndex), tensor_idx);
|
||||
}
|
||||
UpdateVec<uint32_t>(&(graph->inputIndex), tensor_idx);
|
||||
UpdateVec<uint32_t>(&(graph->outputIndex), tensor_idx);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS SubgraphTensorPass::RemoveUselessTensors(schema::MetaGraphT *graph) {
|
||||
for (auto it = graph->allTensors.begin(); it != graph->allTensors.end();) {
|
||||
uint32_t idx = it - graph->allTensors.begin();
|
||||
if (IsUsing(graph, idx)) {
|
||||
it++;
|
||||
} else {
|
||||
it = graph->allTensors.erase(it);
|
||||
UpdateTensorIdx(graph, idx);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS SubgraphTensorPass::SyncMainGraphInputAndOutput(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph->subGraph.size() > 0);
|
||||
graph->subGraph[0]->inputIndices.assign(graph->inputIndex.begin(), graph->inputIndex.end());
|
||||
graph->subGraph[0]->outputIndices.assign(graph->outputIndex.begin(), graph->outputIndex.end());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS SubgraphTensorPass::Run(schema::MetaGraphT *graph) {
|
||||
MS_ASSERT(graph != nullptr);
|
||||
|
||||
int ret = RemoveUselessTensors(graph);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "RemoveUselessTensors failed, ret: " << ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = SetSubgraphTensorIndices(graph);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
ret = SyncMainGraphInputAndOutput(graph);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "SetSubgraphTensorIndices failed, ret: " << ret;
|
||||
return ret;
|
||||
}
|
||||
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H
|
||||
#define MINDSPORE_PREDICT_ISOLATED_SUBGRAPH_TENSOR_PASS_H
|
||||
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include "tools/converter/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class SubgraphTensorPass : public GraphPass {
|
||||
public:
|
||||
SubgraphTensorPass() = default;
|
||||
|
||||
~SubgraphTensorPass() override = default;
|
||||
|
||||
STATUS Run(schema::MetaGraphT *graph) override;
|
||||
|
||||
private:
|
||||
STATUS RemoveUselessTensors(schema::MetaGraphT *graph);
|
||||
bool IsUsing(schema::MetaGraphT *graph, const uint32_t &tensor_idx);
|
||||
STATUS UpdateTensorIdx(schema::MetaGraphT *graph, const uint32_t &tensor_idx);
|
||||
STATUS SyncMainGraphInputAndOutput(schema::MetaGraphT *graph);
|
||||
|
||||
template <typename T>
|
||||
void UpdateVec(std::vector<T> *vec, T element) {
|
||||
for (auto iter = vec->begin(); iter != vec->end(); iter++) {
|
||||
if (*iter > element) {
|
||||
(*iter)--;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_PREDICT_ISOLATED_NODE_REMOVE_PASS_H
|
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue