!10438 [MS][LITE] fix bug of switch pass

From: @mengyuanli
Reviewed-by: 
Signed-off-by:
pull/10438/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ebf055a472

@ -1 +0,0 @@
decoder_step_201217.pb 5

@ -28,30 +28,39 @@
namespace mindspore {
namespace lite {
std::set<uint32_t> SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
schema::MetaGraphT *graph) {
std::set<uint32_t> tensors_indices{};
STATUS SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph,
schema::MetaGraphT *graph, std::set<uint32_t> *tensors_indices) {
for (auto &node_idx : subgraph->nodeIndices) {
if (node_idx >= graph->nodes.size()) {
MS_LOG(ERROR) << "node_idx: " << node_idx << " bigger than graph->nodes.size(): " << graph->nodes.size();
for (auto &subgraph : graph->subGraph) {
MS_LOG(ERROR) << subgraph->name << " : " << subgraph->nodeIndices;
}
return RET_ERROR;
}
auto &node = graph->nodes.at(node_idx);
for (auto &input_idx : node->inputIndex) {
tensors_indices.insert(input_idx);
tensors_indices->insert(input_idx);
}
for (auto &output_idx : node->outputIndex) {
tensors_indices.insert(output_idx);
tensors_indices->insert(output_idx);
}
}
return tensors_indices;
return RET_OK;
}
bool SubgraphNodePass::IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices,
const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph) {
return std::any_of(node->inputIndex.begin(), node->inputIndex.end(),
[&tensors_indices, &subgraph](uint32_t idx) { return tensors_indices.count(idx) > 0; });
}
bool SubgraphNodePass::IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph) {
return (std::any_of(node->inputIndex.begin(), node->inputIndex.end(),
[&tensors_indices, &subgraph](uint32_t idx) {
return tensors_indices.count(idx) > 0 || IsContain(subgraph->inputIndices, idx);
})) &&
(std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) {
return tensors_indices.count(idx) > 0 || IsContain(subgraph->outputIndices, idx);
}));
bool SubgraphNodePass::IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices,
const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph) {
return std::any_of(node->outputIndex.begin(), node->outputIndex.end(),
[&tensors_indices, &subgraph](uint32_t idx) { return tensors_indices.count(idx) > 0; });
}
void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) {
@ -104,12 +113,42 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) {
for (uint32_t i = 0; i < new_nodes.size(); i++) {
if (!IsContain(old_nodes_, new_nodes[i])) {
auto &node = graph->nodes.at(i);
std::vector<SubGraphT *> contain_node_input_subgraphs{};
std::vector<SubGraphT *> contain_node_output_subgraphs{};
for (auto &subgraph : graph->subGraph) {
auto tensors_indices = GetSubgraphAllTensorIndices(subgraph, graph);
if (IsNodeInSubgraph(tensors_indices, node, subgraph)) {
IncreaseSubgraphNodeIndices(i, graph);
subgraph->nodeIndices.push_back(i);
std::set<uint32_t> tensors_indices{};
int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices);
if (ret != RET_OK) {
MS_LOG(ERROR) << "GetSubgraphAllTensorIndices failed.";
return ret;
}
if (IsNodeInputInSubgraph(tensors_indices, node, subgraph)) {
contain_node_input_subgraphs.push_back(subgraph.get());
}
if (IsNodeOutputInSubgraph(tensors_indices, node, subgraph)) {
contain_node_output_subgraphs.push_back(subgraph.get());
}
}
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 &&
contain_node_output_subgraphs[0] != contain_node_input_subgraphs[0]) {
MS_LOG(ERROR) << "not support single node index insert.";
return RET_ERROR;
}
if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 &&
contain_node_output_subgraphs[0] == contain_node_input_subgraphs[0]) {
IncreaseSubgraphNodeIndices(i, graph);
contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
continue;
}
if (contain_node_input_subgraphs.size() == 1) {
IncreaseSubgraphNodeIndices(i, graph);
contain_node_input_subgraphs[0]->nodeIndices.push_back(i);
continue;
}
if (contain_node_output_subgraphs.size() == 1) {
IncreaseSubgraphNodeIndices(i, graph);
contain_node_output_subgraphs[0]->nodeIndices.push_back(i);
continue;
}
}
}

@ -36,9 +36,12 @@ class SubgraphNodePass : public GraphPass {
private:
void DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
void IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph);
std::set<uint32_t> GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph, schema::MetaGraphT *graph);
bool IsNodeInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph);
STATUS GetSubgraphAllTensorIndices(const std::unique_ptr<SubGraphT> &subgraph, schema::MetaGraphT *graph,
std::set<uint32_t> *tensors_indices);
bool IsNodeInputInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph);
bool IsNodeOutputInSubgraph(const std::set<uint32_t> &tensors_indices, const std::unique_ptr<CNodeT> &node,
const std::unique_ptr<SubGraphT> &subgraph);
std::vector<schema::CNodeT *> old_nodes_;
};
} // namespace lite

@ -50,13 +50,16 @@ class SingleSwitchPass {
STATUS ConcatBodySubgraphInputAndOutput();
bool IsLoop();
STATUS InsertMerge();
int GetSubgraphInputTensorIndex(const std::unique_ptr<SubGraphT> &subgraph, const std::unique_ptr<TensorT> &tensor);
int GetSubgraphOutputTensorIndex(const std::unique_ptr<SubGraphT> &subgraph, CNodeT *node);
STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node,
const std::vector<schema::CNodeT *> &subgraph_nodes);
STATUS UpdateSubgraphOutput(const size_t &subgraph_index, schema::CNodeT *partial_node,
const std::vector<schema::CNodeT *> &subgraph_nodes);
std::unique_ptr<schema::TensorT> NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor);
void RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph);
void DoubleIdx(uint32_t *idx);
std::unique_ptr<schema::TensorT> NewTensor(const std::unique_ptr<schema::TensorT> &in_tensor, bool with_data = false);
void IsolateUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph);
void UpdateSwitchOutputIndices(uint32_t *idx);
STATUS BodyGraphVariableInput(std::vector<size_t> *variable_input);
const size_t kSwitchCondIndex = 0;
const size_t kSwitchBodyIndex = 1;
@ -70,10 +73,7 @@ class SingleSwitchPass {
std::vector<schema::CNodeT *> this_graph_nodes_;
std::vector<schema::CNodeT *> body_graph_nodes_;
std::vector<schema::CNodeT *> cond_graph_nodes_;
std::vector<schema::CNodeT *> switch_users_;
size_t switch_node_index_ = -1;
size_t cond_node_index_ = -1;
size_t body_node_index_ = -1;
int32_t this_subgraph_index_ = -1;
int32_t cond_subgraph_index_ = -1;
int32_t body_subgraph_index_ = -1;

@ -346,10 +346,6 @@ STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) {
}
bool InferShapePass::Run(const FuncGraphPtr &func_graph) {
if (func_graph->has_flag("HasInferShaped")) {
return true;
}
if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) {
MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
return false;

Loading…
Cancel
Save