|
|
|
@ -28,30 +28,39 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
|
|
|
|
|
std::set<uint32_t> SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
|
|
|
|
|
schema::MetaGraphT *graph) {
|
|
|
|
|
std::set<uint32_t> tensors_indices{};
|
|
|
|
|
STATUS SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
|
|
|
|
|
schema::MetaGraphT *graph, std::set<uint32_t> *tensors_indices) {
|
|
|
|
|
for (auto &node_idx : subgraph->nodeIndices) {
|
|
|
|
|
if (node_idx >= graph->nodes.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "node_idx: " << node_idx << " bigger than graph->nodes.size(): " << graph->nodes.size();
|
|
|
|
|
for (auto &subgraph : graph->subGraph) {
|
|
|
|
|
MS_LOG(ERROR) << subgraph->name << " : " << subgraph->nodeIndices;
|
|
|
|
|
}
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto &node = graph->nodes.at(node_idx);
|
|
|
|
|
for (auto &input_idx : node->inputIndex) {
|
|
|
|
|
tensors_indices.insert(input_idx);
|
|
|
|
|
tensors_indices->insert(input_idx);
|
|
|
|
|
}
|
|
|
|
|
for (auto &output_idx : node->outputIndex) {
|
|
|
|
|
tensors_indices.insert(output_idx);
|
|
|
|
|
tensors_indices->insert(output_idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return tensors_indices;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SubgraphNodePass::IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
|
|
|
|
|
bool SubgraphNodePass::IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices,
|
|
|
|
|
const std::unique_ptr<CNodeT> &node,
|
|
|
|
|
const std::unique_ptr<SubGraphT> &subgraph) {
|
|
|
|
|
return (std::any_of(node->inputIndex.begin(), node->inputIndex.end(),
|
|
|
|
|
[&tensors_indices, &subgraph](uint32_t idx) {
|
|
|
|
|
return tensors_indices.count(idx) > 0 || IsContain(subgraph->inputIndices, idx);
|
|
|
|
|
})) &&
|
|
|
|
|
(std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
|
|
|
|
|
return tensors_indices.count(idx) > 0 || IsContain(subgraph->outputIndices, idx);
|
|
|
|
|
}));
|
|
|
|
|
return std::any_of(node->inputIndex.begin(), node->inputIndex.end(),
|
|
|
|
|
[&tensors_indices, &subgraph](uint32_t idx) { return tensors_indices.count(idx) > 0; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool SubgraphNodePass::IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices,
|
|
|
|
|
const std::unique_ptr<CNodeT> &node,
|
|
|
|
|
const std::unique_ptr<SubGraphT> &subgraph) {
|
|
|
|
|
return std::any_of(node->outputIndex.begin(), node->outputIndex.end(),
|
|
|
|
|
[&tensors_indices, &subgraph](uint32_t idx) { return tensors_indices.count(idx) > 0; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
|
|
|
|
@ -104,12 +113,42 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
|
|
|
|
|
for (uint32_t i = 0; i < new_nodes.size(); i++) {
|
|
|
|
|
if (!IsContain(old_nodes_, new_nodes[i])) {
|
|
|
|
|
auto &node = graph->nodes.at(i);
|
|
|
|
|
std::vector<SubGraphT *> contain_node_input_subgraphs{};
|
|
|
|
|
std::vector<SubGraphT *> contain_node_output_subgraphs{};
|
|
|
|
|
for (auto &subgraph : graph->subGraph) {
|
|
|
|
|
auto tensors_indices = GetSubgraphAllTensorIndices(subgraph, graph);
|
|
|
|
|
if (IsNodeInSubgraph(tensors_indices, node, subgraph)) {
|
|
|
|
|
std::set<uint32_t> tensors_indices{};
|
|
|
|
|
int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "GetSubgraphAllTensorIndices failed.";
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
if (IsNodeInputInSubgraph(tensors_indices, node, subgraph)) {
|
|
|
|
|
contain_node_input_subgraphs.push_back(subgraph.get());
|
|
|
|
|
}
|
|
|
|
|
if (IsNodeOutputInSubgraph(tensors_indices, node, subgraph)) {
|
|
|
|
|
contain_node_output_subgraphs.push_back(subgraph.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 &&
|
|
|
|
|
contain_node_output_subgraphs[0] != contain_node_input_subgraphs[0]) {
|
|
|
|
|
MS_LOG(ERROR) << "not support single node index insert.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 &&
|
|
|
|
|
contain_node_output_subgraphs[0] == contain_node_input_subgraphs[0]) {
|
|
|
|
|
IncreaseSubgraphNodeIndices(i, graph);
|
|
|
|
|
contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (contain_node_input_subgraphs.size() == 1) {
|
|
|
|
|
IncreaseSubgraphNodeIndices(i, graph);
|
|
|
|
|
subgraph->nodeIndices.push_back(i);
|
|
|
|
|
contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (contain_node_output_subgraphs.size() == 1) {
|
|
|
|
|
IncreaseSubgraphNodeIndices(i, graph);
|
|
|
|
|
contain_node_output_subgraphs[0]->nodeIndices.push_back(i);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|