From 95211c8fceec3292f7cb3a9300818396c0a264b0 Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Tue, 22 Dec 2020 11:05:22 +0800 Subject: [PATCH] set root tensor for merge --- mindspore/lite/test/models_tf.cfg | 1 - .../graph/subgraph_node_pass.cc | 77 ++++-- .../graph/subgraph_node_pass.h | 9 +- .../legacy_optimizer/graph/switch_pass.cc | 220 +++++++++++------- .../legacy_optimizer/graph/switch_pass.h | 12 +- .../tools/optimizer/graph/infershape_pass.cc | 4 - 6 files changed, 202 insertions(+), 121 deletions(-) diff --git a/mindspore/lite/test/models_tf.cfg b/mindspore/lite/test/models_tf.cfg index 7712e1d401..e69de29bb2 100644 --- a/mindspore/lite/test/models_tf.cfg +++ b/mindspore/lite/test/models_tf.cfg @@ -1 +0,0 @@ -decoder_step_201217.pb 5 diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc index e2ad9c11a4..827619c2e5 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.cc @@ -28,30 +28,39 @@ namespace mindspore { namespace lite { -std::set SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr &subgraph, - schema::MetaGraphT *graph) { - std::set tensors_indices{}; +STATUS SubgraphNodePass::GetSubgraphAllTensorIndices(const std::unique_ptr &subgraph, + schema::MetaGraphT *graph, std::set *tensors_indices) { for (auto &node_idx : subgraph->nodeIndices) { + if (node_idx >= graph->nodes.size()) { + MS_LOG(ERROR) << "node_idx: " << node_idx << " bigger than graph->nodes.size(): " << graph->nodes.size(); + for (auto &subgraph : graph->subGraph) { + MS_LOG(ERROR) << subgraph->name << " : " << subgraph->nodeIndices; + } + return RET_ERROR; + } auto &node = graph->nodes.at(node_idx); for (auto &input_idx : node->inputIndex) { - tensors_indices.insert(input_idx); + tensors_indices->insert(input_idx); } for (auto &output_idx : node->outputIndex) { - tensors_indices.insert(output_idx); + tensors_indices->insert(output_idx); } } - return tensors_indices; + return RET_OK; +} + +bool SubgraphNodePass::IsNodeInputInSubgraph(const std::set &tensors_indices, + const std::unique_ptr &node, + const std::unique_ptr &subgraph) { + return std::any_of(node->inputIndex.begin(), node->inputIndex.end(), + [&tensors_indices, &subgraph](uint32_t idx) { return tensors_indices.count(idx) > 0; }); } -bool SubgraphNodePass::IsNodeInSubgraph(const std::set &tensors_indices, const std::unique_ptr &node, - const std::unique_ptr &subgraph) { - return (std::any_of(node->inputIndex.begin(), node->inputIndex.end(), - [&tensors_indices, &subgraph](uint32_t idx) { - return tensors_indices.count(idx) > 0 || IsContain(subgraph->inputIndices, idx); - })) && - (std::any_of(node->outputIndex.begin(), node->outputIndex.end(), [&tensors_indices, &subgraph](uint32_t idx) { - return tensors_indices.count(idx) > 0 || IsContain(subgraph->outputIndices, idx); - })); +bool SubgraphNodePass::IsNodeOutputInSubgraph(const std::set &tensors_indices, + const std::unique_ptr &node, + const std::unique_ptr &subgraph) { + return std::any_of(node->outputIndex.begin(), node->outputIndex.end(), + [&tensors_indices, &subgraph](uint32_t idx) { return tensors_indices.count(idx) > 0; }); } void SubgraphNodePass::DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph) { @@ -104,12 +113,42 @@ STATUS SubgraphNodePass::Run(schema::MetaGraphT *graph) { for (uint32_t i = 0; i < new_nodes.size(); i++) { if (!IsContain(old_nodes_, new_nodes[i])) { auto &node = graph->nodes.at(i); + std::vector contain_node_input_subgraphs{}; + std::vector contain_node_output_subgraphs{}; for (auto &subgraph : graph->subGraph) { - auto tensors_indices = GetSubgraphAllTensorIndices(subgraph, graph); - if (IsNodeInSubgraph(tensors_indices, node, subgraph)) { - IncreaseSubgraphNodeIndices(i, graph); - subgraph->nodeIndices.push_back(i); + std::set tensors_indices{}; + int ret = GetSubgraphAllTensorIndices(subgraph, graph, &tensors_indices); + if (ret != RET_OK) { + MS_LOG(ERROR) << "GetSubgraphAllTensorIndices failed."; + return ret; } + if (IsNodeInputInSubgraph(tensors_indices, node, subgraph)) { + contain_node_input_subgraphs.push_back(subgraph.get()); + } + if (IsNodeOutputInSubgraph(tensors_indices, node, subgraph)) { + contain_node_output_subgraphs.push_back(subgraph.get()); + } + } + if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 && + contain_node_output_subgraphs[0] != contain_node_input_subgraphs[0]) { + MS_LOG(ERROR) << "not support single node index insert."; + return RET_ERROR; + } + if (contain_node_input_subgraphs.size() == 1 && contain_node_output_subgraphs.size() == 1 && + contain_node_output_subgraphs[0] == contain_node_input_subgraphs[0]) { + IncreaseSubgraphNodeIndices(i, graph); + contain_node_input_subgraphs[0]->nodeIndices.push_back(i); + continue; + } + if (contain_node_input_subgraphs.size() == 1) { + IncreaseSubgraphNodeIndices(i, graph); + contain_node_input_subgraphs[0]->nodeIndices.push_back(i); + continue; + } + if (contain_node_output_subgraphs.size() == 1) { + IncreaseSubgraphNodeIndices(i, graph); + contain_node_output_subgraphs[0]->nodeIndices.push_back(i); + continue; } } } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h index 303310b9c3..594c9a0ca1 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/subgraph_node_pass.h @@ -36,9 +36,12 @@ class SubgraphNodePass : public GraphPass { private: void DecreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph); void IncreaseSubgraphNodeIndices(const size_t &node_idx, schema::MetaGraphT *graph); - std::set GetSubgraphAllTensorIndices(const std::unique_ptr &subgraph, schema::MetaGraphT *graph); - bool IsNodeInSubgraph(const std::set &tensors_indices, const std::unique_ptr &node, - const std::unique_ptr &subgraph); + STATUS GetSubgraphAllTensorIndices(const std::unique_ptr &subgraph, schema::MetaGraphT *graph, + std::set *tensors_indices); + bool IsNodeInputInSubgraph(const std::set &tensors_indices, const std::unique_ptr &node, + const std::unique_ptr &subgraph); + bool IsNodeOutputInSubgraph(const std::set &tensors_indices, const std::unique_ptr &node, + const std::unique_ptr &subgraph); std::vector old_nodes_; }; } // namespace lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc index d296919264..773f1c4b7d 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.cc @@ -16,6 +16,7 @@ #include #include +#include #include #include "tools/converter/legacy_optimizer/graph/switch_pass.h" #include "src/common/log_adapter.h" @@ -47,7 +48,10 @@ STATUS SwitchPass::Run(mindspore::schema::MetaGraphT *graph) { STATUS SingleSwitchPass::DoubleSwitchOutput() { origin_switch_output_tensor_indices_ = switch_node_->outputIndex; - MS_ASSERT(origin_switch_output_tensor_indices_.size() == cond_partial_node_->inputIndex.szie()); + if (origin_switch_output_tensor_indices_.size() != cond_partial_node_->inputIndex.size()) { + MS_LOG(ERROR) << "switch node: " << switch_node_->name << " input or output number is not right."; + return RET_ERROR; + } for (size_t i = 0; i < origin_switch_output_tensor_indices_.size(); i++) { auto &switch_out_tensor = graph_->allTensors.at(origin_switch_output_tensor_indices_[i]); const auto &cond_partial_input_tensor = graph_->allTensors.at(cond_partial_node_->inputIndex[i]); @@ -60,7 +64,7 @@ STATUS SingleSwitchPass::DoubleSwitchOutput() { return RET_OK; } -void SingleSwitchPass::DoubleIdx(uint32_t *idx) { +void SingleSwitchPass::UpdateSwitchOutputIndices(uint32_t *idx) { auto iter = std::find(switch_node_->outputIndex.begin(), switch_node_->outputIndex.end(), *idx); if (iter != switch_node_->outputIndex.end()) { int pos = iter - switch_node_->outputIndex.begin(); @@ -69,25 +73,21 @@ void SingleSwitchPass::DoubleIdx(uint32_t *idx) { } STATUS SingleSwitchPass::UpdateSwitchUser() { - std::vector switch_users; for (auto &node_idx : graph_->subGraph.at(this_subgraph_index_)->nodeIndices) { auto &node = graph_->nodes.at(node_idx); for (auto &idx : node->inputIndex) { - if (IsContain(switch_node_->outputIndex, idx)) { - switch_users.push_back(node.get()); - } - DoubleIdx(&idx); + UpdateSwitchOutputIndices(&idx); } } // update graph switch user for (auto &subgraph : graph_->subGraph) { for (auto &idx : subgraph->outputIndices) { - DoubleIdx(&idx); + UpdateSwitchOutputIndices(&idx); } } for (auto &idx : graph_->outputIndex) { - DoubleIdx(&idx); + UpdateSwitchOutputIndices(&idx); } return RET_OK; @@ -104,20 +104,71 @@ bool SingleSwitchPass::IsLoop() { return false; } -std::unique_ptr SingleSwitchPass::NewTensor(const std::unique_ptr &in_tensor) { +std::unique_ptr SingleSwitchPass::NewTensor(const std::unique_ptr &in_tensor, + bool with_data) { auto out_tensor = std::make_unique(); out_tensor->nodeType = in_tensor->nodeType; out_tensor->dims = in_tensor->dims; out_tensor->dataType = in_tensor->dataType; - out_tensor->data = in_tensor->data; out_tensor->format = in_tensor->format; + if (with_data) { + out_tensor->data = in_tensor->data; + } return out_tensor; } +STATUS SingleSwitchPass::BodyGraphVariableInput(std::vector *variable_input) { + auto &body_fg = graph_->subGraph.at(body_subgraph_index_); + auto body_fg_output = body_fg->outputIndices; + for (auto &subgraph_output : body_fg_output) { + for (auto &node : body_graph_nodes_) { + if (node != nullptr && IsContain(node->outputIndex, subgraph_output)) { + int partial_idx = GetSubgraphOutputTensorIndex(body_fg, node); + if (partial_idx == -1) { + MS_LOG(ERROR) << "get input index failed."; + return RET_ERROR; + } + (*variable_input).emplace_back(partial_idx); + } + } + } + return RET_OK; +} + STATUS SingleSwitchPass::InsertMerge() { - int ret = RET_OK; + // update body graph output + auto &body_fg = graph_->subGraph.at(body_subgraph_index_); + body_fg->outputIndices.assign(body_to_cond_partial_node_->inputIndex.begin(), + body_to_cond_partial_node_->inputIndex.end()); + + // remove body_to_cond_partial_node_ from body_graph_nodes_ + for (auto it = body_graph_nodes_.begin(); it != body_graph_nodes_.end();) { + if (*it == body_to_cond_partial_node_) { + it = body_graph_nodes_.erase(it); + } else { + it++; + } + } + + // isolate body_to_cond_partial_node_ + IsolateUselessNode(body_to_cond_partial_node_, graph_); + + std::vector variable_input{}; + int ret = BodyGraphVariableInput(&variable_input); + if (ret != RET_OK) { + MS_LOG(ERROR) << "get body graph variable input failed, ret: " << ret; + return ret; + } + + std::vector const_input{}; + for (size_t i = 0; i < body_partial_node_->inputIndex.size(); i++) { + if (IsContain(variable_input, i)) { + continue; + } + const_input.push_back(i); + } + auto merge_node = std::unique_ptr(new (std::nothrow) CNodeT); - MS_ASSERT(merge_node != nullptr); auto primitiveT = std::unique_ptr(new (std::nothrow) PrimitiveT); MS_ASSERT(primitiveT != nullptr); merge_node->primitive = std::move(primitiveT); @@ -129,8 +180,6 @@ STATUS SingleSwitchPass::InsertMerge() { MS_ASSERT(merge_param != nullptr); merge_node->primitive->value.value = merge_param.release(); - merge_node->inputIndex.assign(cond_partial_node_->inputIndex.begin(), cond_partial_node_->inputIndex.end()); - // merge node output is same as switch for (auto &out_index : origin_switch_output_tensor_indices_) { auto &switch_out_tensor = graph_->allTensors.at(out_index); @@ -139,12 +188,30 @@ STATUS SingleSwitchPass::InsertMerge() { merge_node->outputIndex.push_back(graph_->allTensors.size() - 1); } - // double merge inputs to contain the outputs of body node - for (auto &index : cond_partial_node_->inputIndex) { - auto &in_tensor = graph_->allTensors.at(index); - auto tensor = NewTensor(in_tensor); - graph_->allTensors.push_back(std::move(tensor)); - merge_node->inputIndex.push_back(graph_->allTensors.size() - 1); + merge_node->inputIndex.assign(cond_partial_node_->inputIndex.begin(), cond_partial_node_->inputIndex.end()); + + std::set input_set{}; + for (auto &iter : merge_node->inputIndex) { + if (input_set.find(iter) != input_set.end()) { + auto &in_tensor = graph_->allTensors.at(iter); + auto tensor = NewTensor(in_tensor, true); + graph_->allTensors.push_back(std::move(tensor)); + iter = graph_->allTensors.size() - 1; + } + input_set.insert(iter); + } + + // double merge inputs to contain the outputs of body node + auto old_merge_input = merge_node->inputIndex; + for (size_t i = 0; i < old_merge_input.size(); i++) { + auto &in_tensor = graph_->allTensors.at(old_merge_input[i]); + if (IsContain(const_input, i)) { + merge_node->inputIndex.push_back(old_merge_input[i]); + } else { + auto tensor = NewTensor(in_tensor); + graph_->allTensors.push_back(std::move(tensor)); + merge_node->inputIndex.push_back(graph_->allTensors.size() - 1); + } } // insert merge node before the cond graph @@ -182,46 +249,12 @@ STATUS SingleSwitchPass::InsertMerge() { graph_->nodes.push_back(std::move(merge_node)); graph_->subGraph.at(this_subgraph_index_)->nodeIndices.push_back(graph_->nodes.size() - 1); - // update bodu graph output - graph_->subGraph.at(body_subgraph_index_) - ->outputIndices.assign(body_to_cond_partial_node_->inputIndex.begin(), - body_to_cond_partial_node_->inputIndex.end()); - - // erase body_to_cond_partial_node_ - RemoveUselessNode(body_to_cond_partial_node_, graph_); - return ret; + return RET_OK; } -void SingleSwitchPass::RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph) { +void SingleSwitchPass::IsolateUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph) { partial_node->inputIndex.clear(); partial_node->outputIndex.clear(); - - int pos = -1; - for (size_t i = 0; i < graph->nodes.size(); ++i) { - if (graph->nodes.at(i).get() == partial_node) { - pos = i; - break; - } - } - - if (pos == -1) { - return; - } - - graph->nodes.erase(graph->nodes.begin() + pos); - - for (auto &subgraph : graph->subGraph) { - for (auto it = subgraph->nodeIndices.begin(); it != subgraph->nodeIndices.end();) { - if (*it == static_cast(pos)) { - it = subgraph->nodeIndices.erase(it); - } else { - if (*it > static_cast(pos)) { - (*it)--; - } - it++; - } - } - } } size_t SingleSwitchPass::InitThisGraphIndex() { @@ -265,12 +298,10 @@ STATUS SingleSwitchPass::Init() { for (auto &out_index : iter->get()->outputIndex) { if (out_index == switch_node_->inputIndex[kSwitchCondIndex]) { cond_partial_node_ = iter->get(); - cond_node_index_ = iter - graph_->nodes.begin(); find_cond_node = true; } if (out_index == switch_node_->inputIndex[kSwitchBodyIndex]) { body_partial_node_ = iter->get(); - body_node_index_ = iter - graph_->nodes.begin(); find_body_node = true; } } @@ -301,6 +332,41 @@ STATUS SingleSwitchPass::Init() { return RET_OK; } +int SingleSwitchPass::GetSubgraphInputTensorIndex(const std::unique_ptr &subgraph, + const std::unique_ptr &tensor) { + int partial_idx = -1; + if (tensor->name.find("_input_") != std::string::npos) { + // get parameter input index k. subgraph name + “_input_" + "k" + auto pos = subgraph->name.size() + sizeof("_input_"); + auto pos2 = tensor->name.find('_', pos); + auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); + partial_idx = std::stoi(idx_str); + } + + if (tensor->name.find("_output_") != std::string::npos) { + // get parameter input index k. subgraph name + “_output_" + "k" + auto pos = subgraph->name.size() + sizeof("_output_"); + auto pos2 = tensor->name.find('_', pos); + auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); + partial_idx = std::stoi(idx_str); + } + return partial_idx; +} + +int SingleSwitchPass::GetSubgraphOutputTensorIndex(const std::unique_ptr &subgraph, CNodeT *node) { + int partial_idx = -1; + if (node->name == "LogicalAnd") { + partial_idx = 0; + } else { + // get parameter input index k. subgraph name + “_output_" + "k" + auto pos = subgraph->name.size() + sizeof("_output_"); + auto pos2 = node->name.find('_', pos); + auto idx_str = node->name.substr(pos - 1, pos2 - pos + 1); + partial_idx = std::stoi(idx_str); + } + return partial_idx; +} + STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node, const std::vector &subgraph_nodes) { if (partial_node == nullptr || subgraph_nodes.empty()) { @@ -315,27 +381,11 @@ STATUS SingleSwitchPass::UpdateSubgraphInput(const size_t &subgraph_index, schem std::vector> tmp_inputs_order{}; for (unsigned int &subgraph_input : subgraph_inputs) { auto &tensor = graph_->allTensors.at(subgraph_input); - if (tensor->name.size() < subgraph->name.size() + 8) { - MS_LOG(ERROR) << "tensor name: " << tensor->name << " not right."; + int partial_idx = GetSubgraphInputTensorIndex(subgraph, tensor); + if (partial_idx == -1) { + MS_LOG(ERROR) << "get input index failed."; return RET_ERROR; } - int partial_idx = -1; - if (tensor->name.find("_input_") != std::string::npos) { - // get parameter input index k. subgraph name + “_input_" + "k" - auto pos = subgraph->name.size() + sizeof("_input_"); - auto pos2 = tensor->name.find('_', pos); - auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); - partial_idx = std::stoi(idx_str); - } - - if (tensor->name.find("_output_") != std::string::npos) { - // get parameter input index k. subgraph name + “_output_" + "k" - auto pos = subgraph->name.size() + sizeof("_output_"); - auto pos2 = tensor->name.find('_', pos); - auto idx_str = tensor->name.substr(pos - 1, pos2 - pos + 1); - partial_idx = std::stoi(idx_str); - } - subgraph_input_map.insert(std::pair{subgraph_input, partial_inputs[partial_idx]}); tmp_inputs_order.emplace_back(partial_idx, partial_inputs[partial_idx]); } @@ -374,15 +424,10 @@ STATUS SingleSwitchPass::UpdateSubgraphOutput(const size_t &subgraph_index, sche for (unsigned int &subgraph_output : subgraph_outputs) { for (auto &node : subgraph_nodes) { if (IsContain(node->outputIndex, subgraph_output)) { - int partial_idx = -1; - if (node->name == "LogicalAnd") { - partial_idx = 0; - } else { - // get parameter input index k. subgraph name + “_output_" + "k" - auto pos = subgraph->name.size() + sizeof("_output_"); - auto pos2 = node->name.find('_', pos); - auto idx_str = node->name.substr(pos - 1, pos2); - partial_idx = std::stoi(idx_str); + int partial_idx = GetSubgraphOutputTensorIndex(subgraph, node); + if (partial_idx == -1) { + MS_LOG(ERROR) << "get input index failed."; + return RET_ERROR; } subgraph_output_map.insert(std::pair{subgraph_output, partial_outputs[partial_idx]}); tmp_outputs_order.emplace_back(partial_idx, partial_outputs[partial_idx]); @@ -473,7 +518,6 @@ STATUS SingleSwitchPass::Run() { MS_LOG(ERROR) << "ConcatBodySubgraphInputAndOutput failed, ret: " << ret; return ret; } - return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h index f65eab862a..8d82054d0e 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/switch_pass.h @@ -50,13 +50,16 @@ class SingleSwitchPass { STATUS ConcatBodySubgraphInputAndOutput(); bool IsLoop(); STATUS InsertMerge(); + int GetSubgraphInputTensorIndex(const std::unique_ptr &subgraph, const std::unique_ptr &tensor); + int GetSubgraphOutputTensorIndex(const std::unique_ptr &subgraph, CNodeT *node); STATUS UpdateSubgraphInput(const size_t &subgraph_index, schema::CNodeT *partial_node, const std::vector &subgraph_nodes); STATUS UpdateSubgraphOutput(const size_t &subgraph_index, schema::CNodeT *partial_node, const std::vector &subgraph_nodes); - std::unique_ptr NewTensor(const std::unique_ptr &in_tensor); - void RemoveUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph); - void DoubleIdx(uint32_t *idx); + std::unique_ptr NewTensor(const std::unique_ptr &in_tensor, bool with_data = false); + void IsolateUselessNode(schema::CNodeT *partial_node, schema::MetaGraphT *graph); + void UpdateSwitchOutputIndices(uint32_t *idx); + STATUS BodyGraphVariableInput(std::vector *variable_input); const size_t kSwitchCondIndex = 0; const size_t kSwitchBodyIndex = 1; @@ -70,10 +73,7 @@ class SingleSwitchPass { std::vector this_graph_nodes_; std::vector body_graph_nodes_; std::vector cond_graph_nodes_; - std::vector switch_users_; size_t switch_node_index_ = -1; - size_t cond_node_index_ = -1; - size_t body_node_index_ = -1; int32_t this_subgraph_index_ = -1; int32_t cond_subgraph_index_ = -1; int32_t body_subgraph_index_ = -1; diff --git a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc index d01ff04fa5..fb8b84cddf 100644 --- a/mindspore/lite/tools/optimizer/graph/infershape_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/infershape_pass.cc @@ -346,10 +346,6 @@ STATUS InferShapePass::SwitchCNodeInferShape(const CNodePtr &switch_cnode) { } bool InferShapePass::Run(const FuncGraphPtr &func_graph) { - if (func_graph->has_flag("HasInferShaped")) { - return true; - } - if (fmk_type != lite::converter::FmkType_TF && fmk_type != lite::converter::FmkType_TFLITE) { MS_LOG(INFO) << "The framework type of model should be tf/tflite."; return false;