diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
index 7a1c6b1f22..406202a274 100755
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/CMakeLists.txt
@@ -19,6 +19,7 @@ file(GLOB GRAPH_PASS
         ${CMAKE_CURRENT_SOURCE_DIR}/select_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_node_pass.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/subgraph_tensor_pass.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/nested_loop_expand_pass.cc
         )
 set_property(SOURCE ${GRAPH_PASS} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_LITE)
 add_library(graph_pass_mid OBJECT ${GRAPH_PASS})
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc
new file mode 100644
index 0000000000..27d3afd9e0
--- /dev/null
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.cc
@@ -0,0 +1,105 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+#include <memory>
+#include <algorithm>
+#include <string>
+#include "tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h"
+#include "src/common/log_adapter.h"
+#include "src/common/utils.h"
+#include "tools/common/graph_util.h"
+#include "include/errorcode.h"
+#include "schema/inner/model_generated.h"
+
+namespace mindspore {
+namespace lite {
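+// Returns true when `node` is a Partial whose target subgraph itself contains
+// another Partial, i.e. the node carries nested control flow.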
+bool NestedLoopExpandPass::IsNestedPartial(const std::unique_ptr<CNodeT> &node) {
+  if (node->primitive->value.type != PrimitiveType_Partial) {
+    return false;
+  }
+  auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
+  auto &this_subgraph = graph_->subGraph.at(subgraph_idx);
+
+  for (auto &node_idx : this_subgraph->nodeIndices) {
+    auto &cnode = graph_->nodes.at(node_idx);
+    if (cnode->primitive->value.type == PrimitiveType_Partial) {
+      return true;
+    }
+  }
+  return false;
+}
+
+void NestedLoopExpandPass::ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph) {
+  bool is_changed = false;
+  for (auto &node_idx : main_graph->nodeIndices) {
+    auto &node = graph_->nodes.at(node_idx);
+    if (!IsNestedPartial(node)) {
+      continue;
+    }
+    is_changed = true;
+    auto subgraph_idx = ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex;
+    auto &this_subgraph = graph_->subGraph.at(subgraph_idx);
+    subgraph_to_drop_.push_back(subgraph_idx);
+    auto partial_pos = std::find(main_graph->nodeIndices.begin(), main_graph->nodeIndices.end(), node_idx);
+    std::vector<uint32_t> tmp;
+    tmp.assign(main_graph->nodeIndices.begin(), partial_pos);
+    tmp.insert(tmp.end(), this_subgraph->nodeIndices.begin(), this_subgraph->nodeIndices.end());
+    tmp.insert(tmp.end(), partial_pos + 1, main_graph->nodeIndices.end());
+    main_graph->nodeIndices.assign(tmp.begin(), tmp.end());
+    // assign() invalidated this range-for; restart via the recursion below.
+    break;
+  }
+
+  if (is_changed) {
+    ReplacePartialNodeWithSubgraph(main_graph);
+  }
+}
+
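+// Inline every nested Partial into the main graph, erase the subgraphs that
+// were inlined, and shift the remaining Partials' subGraphIndex down by the
+// number of dropped subgraphs.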
+STATUS NestedLoopExpandPass::Run(schema::MetaGraphT *graph) {
+  graph_ = graph;
+  auto &main_graph = graph_->subGraph[0];
+
+  ReplacePartialNodeWithSubgraph(main_graph);
+
+  for (auto idx : subgraph_to_drop_) {
+    graph_->subGraph.at(idx) = nullptr;
+  }
+
+  for (auto it = graph_->subGraph.begin(); it != graph_->subGraph.end();) {
+    if ((*it) == nullptr) {
+      it = graph_->subGraph.erase(it);
+    } else {
+      it++;
+    }
+  }
+
+  for (auto &node : graph_->nodes) {
+    if (node->primitive->value.type == PrimitiveType_Partial) {
+      ((schema::PartialT *)(node->primitive->value.value))->subGraphIndex -= subgraph_to_drop_.size();
+    }
+  }
+
+  return RET_OK;
+}
+
+}  // namespace lite
+}  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h b/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h
new file mode 100644
index 0000000000..665c78ba4f
--- /dev/null
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/nested_loop_expand_pass.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H
+#define MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H
+
+#include <memory>
+#include <vector>
+#include <string>
+#include <utility>
+#include "tools/converter/optimizer.h"
+
+namespace mindspore {
+namespace lite {
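+// Graph pass that flattens nested Partial subgraphs so downstream passes only
+// ever see a single level of control-flow subgraphs.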
+class NestedLoopExpandPass : public GraphPass {
+ public:
+  NestedLoopExpandPass() = default;
+
+  ~NestedLoopExpandPass() override = default;
+
+  STATUS Run(schema::MetaGraphT *graph) override;
+
+ private:
+  bool IsNestedPartial(const std::unique_ptr<CNodeT> &node);
+
+  void ReplacePartialNodeWithSubgraph(const std::unique_ptr<SubGraphT> &main_graph);
+
+  schema::MetaGraphT *graph_ = nullptr;
+
+  std::vector<size_t> subgraph_to_drop_{};
+};
+}  // namespace lite
+}  // namespace mindspore
+#endif  // MINDSPORE_LITE_NESTED_LOOP_EXPAND_PASS_H
diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc
index b078967537..61266115cc 100644
--- a/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc
+++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/tensor_name_pass.cc
@@ -35,7 +35,7 @@ STATUS TensorNamePass::Run(schema::MetaGraphT *graph) {
       auto tensor_id = node->inputIndex.at(i);
       auto &tensor = graph->allTensors.at(tensor_id);
       if (tensor->name.empty()) {
-        MS_LOG(WARNING) << "input tensor (id = " << tensor_id << ") name is null";
+        MS_LOG(DEBUG) << "input tensor (id = " << tensor_id << ") name is null";
         tensor->name = node->name + "/input-" + std::to_string(i);
       }
     }
diff --git a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc
index 102428063b..e5c4185189 100644
--- a/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/update_conv2d_param_pass.cc
@@ -57,27 +57,27 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) {
     auto conv2d_cnode = node->cast<CNodePtr>();
     auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv2d_cnode->input(0));
     if (primitive_c == nullptr) {
-      MS_LOG(ERROR) << "Conv2D node has no primitiveC.";
+      MS_LOG(DEBUG) << "Conv2D node has no primitiveC.";
       continue;
     }
     auto primT = primitive_c->primitiveT();
     if (primT == nullptr) {
-      MS_LOG(ERROR) << "Conv2D node has no primitiveT.";
+      MS_LOG(DEBUG) << "Conv2D node has no primitiveT.";
       continue;
     }
     auto conv2d_primt = primT->value.AsConv2D();
     auto weight_node = conv2d_cnode->input(lite::kAnfPopulaterInputNumTwo);
     if (weight_node == nullptr) {
-      MS_LOG(ERROR) << "Conv2D weight node is nullptr.";
+      MS_LOG(DEBUG) << "Conv2D weight node is nullptr.";
       continue;
     }
     if (!weight_node->isa<Parameter>()) {
-      MS_LOG(ERROR) << "Conv2D weight node is not parameter.";
+      MS_LOG(DEBUG) << "Conv2D weight node is not parameter.";
       continue;
     }
     auto weight_param = weight_node->cast<ParameterPtr>();
     if (!weight_param->has_default()) {
-      MS_LOG(ERROR) << "Conv2D weight node is not parameter.";
+      MS_LOG(DEBUG) << "Conv2D weight node has no default parameter.";
       continue;
     }
     auto default_param = weight_param->default_param();
diff --git a/mindspore/lite/tools/optimizer/graph/while_pass.cc b/mindspore/lite/tools/optimizer/graph/while_pass.cc
index c80f5e8741..486f77f7bc 100644
--- a/mindspore/lite/tools/optimizer/graph/while_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/while_pass.cc
@@ -44,29 +44,11 @@ ValueNodePtr WhilePass::GetSwitchAnfPrim() {
     return nullptr;
   }
 
-  auto partial_prim = std::make_shared<lite::PrimitiveC>(switch_primitiveT);
+  auto partial_prim = std::make_shared<lite::Switch>(switch_primitiveT);
   ValueNodePtr partial_anf_prim = NewValueNode(partial_prim);
   return partial_anf_prim;
 }
 
-void WhilePass::ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode,
-                             std::string para_name) {
-  for (auto &node : node_list) {
-    if (utils::isa<CNodePtr>(node)) {
-      auto cnode = utils::cast<CNodePtr>(node);
-      for (size_t k = 0; k < cnode->inputs().size(); k++) {
-        if (!utils::isa<ParameterPtr>(cnode->input(k))) {
-          continue;
-        }
-        auto para_input = utils::cast<ParameterPtr>(cnode->input(k));
-        if (para_input->name() == para_name) {
-          cnode->set_input(k, new_input_cnode);
-        }
-      }
-    }
-  }
-}
-
 bool WhilePass::Run(const FuncGraphPtr &graph) {
   auto node_list = TopoSort(graph->get_return());
   static int count = 0;
@@ -87,34 +69,23 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
     // the order is fixed.
     auto cond_vnode = while_cnode->input(kWhileCondIndex);
     auto body_vnode = while_cnode->input(kWhileBodyIndex);
-
-    // body_vnode->cast<ValueNodePtr>()->set_value()
     auto cond_fg = GetValueNode<std::shared_ptr<FuncGraph>>(cond_vnode);
    auto body_fg = GetValueNode<std::shared_ptr<FuncGraph>>(body_vnode);
-
    if (cond_fg == nullptr || body_fg == nullptr) {
      MS_LOG(ERROR) << "Get value as func_graph failed.";
      lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_FAILED);
      return false;
    }
-
-    // create cond partial cnode
    std::vector<AnfNodePtr> cond_partial_op_inputs{cond_vnode};
-
-    // create body partial cnode
    std::vector<AnfNodePtr> body_partial_op_inputs{body_vnode};
-
-    // add while op input to cond_cnode and body_cnode
    cond_partial_op_inputs.insert(cond_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
                                  while_cnode->inputs().end());
    body_partial_op_inputs.insert(body_partial_op_inputs.end(), while_cnode->inputs().begin() + kWhileMinInputSize,
                                  while_cnode->inputs().end());
-
    static int idx = 0;
    auto cond_partial_node = graph->NewCNode(cond_partial_op_inputs);
    cond_partial_node->set_fullname_with_scope("Partial-while-cond-" + std::to_string(idx));
    cond_partial_node->set_abstract(cond_fg->output()->abstract());
-
    auto body_partial_node = graph->NewCNode(body_partial_op_inputs);
    body_partial_node->set_fullname_with_scope("Partial-while-body-" + std::to_string(idx));
    idx++;
@@ -166,7 +137,6 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
      }
      abstract_list.push_back(cnode->abstract());
    }
-
    switch_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
 
    // create cond partial cnode
@@ -176,7 +146,6 @@ bool WhilePass::Run(const FuncGraphPtr &graph) {
      manager->SetEdge(node_user.first, node_user.second, switch_cnode);
    }
  }
-
  return true;
}
}  // namespace mindspore::opt
diff --git a/mindspore/lite/tools/optimizer/graph/while_pass.h b/mindspore/lite/tools/optimizer/graph/while_pass.h
index e37e7eeb90..21ce33d855 100644
--- a/mindspore/lite/tools/optimizer/graph/while_pass.h
+++ b/mindspore/lite/tools/optimizer/graph/while_pass.h
@@ -32,7 +32,6 @@ class WhilePass : public Pass {
   bool Run(const FuncGraphPtr &graph) override;
 
  private:
-  void ReplaceInput(const std::vector<AnfNodePtr> &node_list, AnfNodePtr new_input_cnode, std::string para_name);
   ValueNodePtr GetSwitchAnfPrim();
   const size_t kWhileMinInputSize = 3;