From 46a22f4a7bd2e70d42949099d8b5fef28581f0ef Mon Sep 17 00:00:00 2001 From: xuanyue Date: Fri, 5 Feb 2021 16:02:51 +0800 Subject: [PATCH] match side effect in lite --- mindspore/lite/test/CMakeLists.txt | 2 +- .../lite/tools/anf_exporter/anf_exporter.cc | 47 --------------- .../lite/tools/anf_exporter/anf_exporter.h | 1 - mindspore/lite/tools/converter/CMakeLists.txt | 2 +- .../lite/tools/converter/anf_transform.cc | 4 +- ...ve_pass.cc => redundant_op_remove_pass.cc} | 60 ++++++++++++------- ...move_pass.h => redundant_op_remove_pass.h} | 14 ++--- 7 files changed, 49 insertions(+), 81 deletions(-) rename mindspore/lite/tools/optimizer/graph/{identity_remove_pass.cc => redundant_op_remove_pass.cc} (67%) rename mindspore/lite/tools/optimizer/graph/{identity_remove_pass.h => redundant_op_remove_pass.h} (72%) diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 413179be4a..74a5e2265b 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -225,7 +225,7 @@ if(ENABLE_CONVERTER) ${LITE_DIR}/tools/optimizer/graph/update_conv2d_param_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_cast_node_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/unused_transpose_node_remove_pass.cc - ${LITE_DIR}/tools/optimizer/graph/identity_remove_pass.cc + ${LITE_DIR}/tools/optimizer/graph/redundant_op_remove_pass.cc ${LITE_DIR}/tools/optimizer/graph/infershape_pass.cc ${LITE_DIR}/tools/optimizer/graph/slice_prepose_pass.cc ${LITE_DIR}/tools/optimizer/graph/mindir_adjust_pass.cc diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index a0330d1c36..4654158871 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -59,41 +59,6 @@ void AnfExporter::RemoveIfMakeTuple(const CNodePtr &cnode) { } } -void AnfExporter::RemoveIfDepend(const CNodePtr &cnode) { - bool hasDepend = false; - std::vector inputs; - inputs.clear(); - - inputs.emplace_back(cnode->input(0)); - for (size_t i = 1; i < cnode->inputs().size(); ++i) { - AnfNodePtr inputNode = cnode->input(i); - if (!inputNode->isa()) { - inputs.emplace_back(cnode->input(i)); - continue; - } - auto dependNode = utils::cast(inputNode); - if (IsPrimitiveCNode(dependNode, schema::PrimitiveType_Depend) || - IsPrimitiveCNode(dependNode, schema::PrimitiveType_ControlDepend)) { - hasDepend = true; - bool maskOut = (dependNode->inputs().size() == 3); - for (size_t j = 1; j < dependNode->inputs().size(); ++j) { - AnfNodePtr dependInputNode = dependNode->input(j); - if (dependInputNode->isa()) { - inputs.emplace_back(dependInputNode); - if (maskOut) { - break; - } - } - } - } else { - inputs.emplace_back(cnode->input(i)); - } - } - if (hasDepend) { - cnode->set_inputs(inputs); - } -} - int AnfExporter::ConvertQuantParam(const std::unique_ptr &meta_graph, const std::shared_ptr &primitive, const std::unique_ptr &dst_node) { @@ -286,23 +251,11 @@ int AnfExporter::Anf2Fb(const FuncGraphPtr &func_graph, const std::unique_ptrType() == schema::PrimitiveType_TupleGetItem) || -#ifdef SUPPORT_TRAIN - (primitive_c->Type() == schema::PrimitiveType_Depend) || - (primitive_c->Type() == schema::PrimitiveType_ControlDepend) || -#endif (primitive_c->Type() == schema::PrimitiveType_MakeTuple)) { continue; } -#ifndef SUPPORT_TRAIN - RemoveIfMakeTuple(cnode); -#endif auto primT = primitive_c->primitiveT(); auto node = std::make_unique(); if (node == nullptr) { diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.h b/mindspore/lite/tools/anf_exporter/anf_exporter.h index 64bae9af1e..2a0ea78ce5 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.h +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.h @@ -41,7 +41,6 @@ class AnfExporter { int SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, schema::CNodeT *fb_node); static void RemoveIfMakeTuple(const CNodePtr &cnode); - static void RemoveIfDepend(const CNodePtr &cnode); protected: int ConvertInputCNode(const std::shared_ptr &input_anode, schema::CNodeT *output_cnode); diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index 01b880ce2b..3420b057c3 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -59,7 +59,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ../optimizer/graph/update_conv2d_param_pass.cc ../optimizer/graph/unused_cast_node_remove_pass.cc ../optimizer/graph/unused_transpose_node_remove_pass.cc - ../optimizer/graph/identity_remove_pass.cc + ../optimizer/graph/redundant_op_remove_pass.cc ../optimizer/graph/infershape_pass.cc ../optimizer/graph/slice_prepose_pass.cc ../optimizer/graph/mindir_adjust_pass.cc diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index dad863ece9..49e4156fd2 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -34,7 +34,7 @@ #include "tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.h" #include "tools/optimizer/graph/mindir_adjust_pass.h" #include "tools/optimizer/graph/mindir_inputs_adjust_pass.h" -#include "tools/optimizer/graph/identity_remove_pass.h" +#include "tools/optimizer/graph/redundant_op_remove_pass.h" #include "tools/optimizer/graph/weight_format_hardcode_pass.h" #include "tools/optimizer/graph/weight_format_transform_pass.h" #include "tools/optimizer/graph/clip_convert_activation_pass.h" @@ -144,7 +144,7 @@ int AnfTransform::AddConvertPass(const std::shared_ptr &opt int AnfTransform::AddConstFoldPass(const std::shared_ptr &optimizer, const converter::Flags *config) { auto const_fold_pm = std::make_shared("const fold fusion pass manager", false); - const_fold_pm->AddPass(std::make_shared()); + const_fold_pm->AddPass(std::make_shared()); if (!config->trainModel) { auto inne_context_ptr = std::make_shared(); inne_context_ptr->Init(); diff --git a/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc similarity index 67% rename from mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc rename to mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc index 76f9c52284..54b561975b 100644 --- a/mindspore/lite/tools/optimizer/graph/identity_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -13,37 +13,41 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "tools/optimizer/graph/identity_remove_pass.h" +#include "tools/optimizer/graph/redundant_op_remove_pass.h" +#include #include "mindspore/lite/include/errorcode.h" #include "src/ops/primitive_c.h" namespace mindspore::opt { -int RemoveIdentityOpPass::ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { +namespace { +constexpr size_t InputDoubleNum = 2; +constexpr size_t InputTripleNum = 3; +constexpr auto kNameLoad = "Load"; +constexpr auto kNameUpdateState = "UpdateState"; +} // namespace +int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { if (!utils::isa(anf_node)) { MS_LOG(DEBUG) << "anf node is node a cnode."; return lite::RET_NO_CHANGE; } auto type = opt::GetCNodeType(anf_node); - if (type != schema::PrimitiveType_Identity) { - MS_LOG(DEBUG) << "anf node is not a identity node."; - return lite::RET_NO_CHANGE; - } - auto identity_cnode = anf_node->cast(); - if (identity_cnode->inputs().size() != lite::kDoubleNum) { - MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; - remove_cnode_.insert(anf_node); - return lite::RET_NO_CHANGE; - } else { - bool replace_succ = manager->Replace(anf_node, identity_cnode->input(1)); - if (!replace_succ) { - MS_LOG(ERROR) << "replace identity failed."; - return lite::RET_ERROR; + auto cnode = anf_node->cast(); + if (type == schema::PrimitiveType_Identity) { + if (cnode->size() != InputDoubleNum) { + MS_LOG(DEBUG) << "The node inputs size is bigger than 1"; + remove_cnode_.insert(anf_node); + return lite::RET_NO_CHANGE; } } + bool replace_succ = manager->Replace(anf_node, cnode->input(1)); + if (!replace_succ) { + MS_LOG(ERROR) << "replace redundant op failed."; + return lite::RET_ERROR; + } return RET_OK; } -int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { +int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) { if (!utils::isa(anf_node)) { MS_LOG(DEBUG) << "anf node is node a cnode."; return lite::RET_NO_CHANGE; @@ -53,7 +57,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const return lite::RET_NO_CHANGE; } auto cnode = anf_node->cast(); - if (cnode->inputs().size() != 3) { + if (cnode->inputs().size() != InputTripleNum) { MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size(); return RET_ERROR; } @@ -81,7 +85,7 @@ int RemoveIdentityOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const return lite::RET_OK; } -bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { +bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) { MS_ASSERT(func_graph != nullptr); auto manager = func_graph->manager(); MS_ASSERT(manager != nullptr); @@ -93,10 +97,22 @@ bool RemoveIdentityOpPass::Run(const FuncGraphPtr &func_graph) { } auto type = opt::GetCNodeType(node); if (type == schema::PrimitiveType_Identity) { - status = ReplaceIdentity(node, manager); - } else if (type == schema::PrimitiveType_TupleGetItem) { + status = ReplaceOp(node, manager); + } + if (CheckPrimitiveType(node, std::make_shared(kNameLoad))) { + status = ReplaceOp(node, manager); + } + if (CheckPrimitiveType(node, std::make_shared(kNameUpdateState))) { + status = ReplaceOp(node, manager); + } + if (type == schema::PrimitiveType_Depend || + type == schema::PrimitiveType_ControlDepend) { // ControlDepend delete next version. + status = ReplaceOp(node, manager); + } + if (type == schema::PrimitiveType_TupleGetItem) { status = ReplaceTupleGetItem(node, manager); - } else if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) { + } + if (type == schema::PrimitiveType_If || type == schema::PrimitiveType_While) { auto sub_func_graph = GetValueNode(node->cast()->input(1)); if (sub_func_graph == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); diff --git a/mindspore/lite/tools/optimizer/graph/identity_remove_pass.h b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h similarity index 72% rename from mindspore/lite/tools/optimizer/graph/identity_remove_pass.h rename to mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h index e9799fe896..8e7237bbe4 100644 --- a/mindspore/lite/tools/optimizer/graph/identity_remove_pass.h +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ -#define MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ +#ifndef MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_ +#define MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_ #include #include #include "backend/optimizer/common/pass.h" @@ -24,11 +24,11 @@ using mindspore::lite::converter::FmkType; namespace mindspore::opt { -class RemoveIdentityOpPass : public Pass { +class RemoveRedundantOpPass : public Pass { public: - RemoveIdentityOpPass() : Pass("remove_identity_pass") {} - ~RemoveIdentityOpPass() override = default; - int ReplaceIdentity(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); + RemoveRedundantOpPass() : Pass("remove_redundant_op_pass") {} + ~RemoveRedundantOpPass() override = default; + int ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); int ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager); bool Run(const FuncGraphPtr &graph) override; @@ -36,4 +36,4 @@ class RemoveIdentityOpPass : public Pass { std::set remove_cnode_; }; } // namespace mindspore::opt -#endif // MINDSPORE_LITE_SRC_PASS_REMOVE_IDENTITY_PASS_H_ +#endif // MINDSPORE_LITE_SRC_PASS_REDUNDANT_OP_REMOVE_PASS_H_