From 0d141c1d02f664050e66d6b086cdd889ac92539c Mon Sep 17 00:00:00 2001 From: mengyuanli Date: Tue, 9 Mar 2021 10:38:35 +0800 Subject: [PATCH] fix GRU fusion for encorder_0111.pb --- mindspore/lite/src/ops/ops_utils.cc | 5 +++++ .../fusion/bidirection_tf_gru_cell_fusion.cc | 14 +++++++------- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index 1e358e0f5a..82fd3b0446 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -236,6 +236,10 @@ schema::PrimitiveT *DropoutGradPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *GRUPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} schema::PrimitiveT *EltwisePrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -790,6 +794,7 @@ RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator) RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator); RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator); RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator); +RegistryMSOps g_gRUPrimitiveCreatorRegistry("GRU", GRUPrimitiveCreator); RegistryMSOps g_hashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator); RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator); RegistryMSOps g_invertPermutationPrimitiveCreatorRegistry("InvertPermutation", InvertPermutationPrimitiveCreator); diff --git a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc index 399ec240e0..730202a47d 100644 --- a/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/bidirection_tf_gru_cell_fusion.cc @@ -70,15 +70,15 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name, const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { // forward - auto fw_max1 = - VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_}); - auto fw_max2 = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMaximum)), - std::make_shared(IsParameterNode), fw_max1}); + auto fw_reduce = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), + input_length_, std::make_shared(IsParameterNode)}); + auto fw_max = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMaximum)), + std::make_shared(IsParameterNode), fw_reduce}); auto fw_shape = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimShape)), transpose_input_}); auto fw_stride = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimStridedSlice)), fw_shape, std::make_shared()}); - auto fw_min = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max2}); + auto fw_min = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max}); auto fw_reserve = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTensorListReserve)), std::make_shared(IsParameterNode), fw_stride}); @@ -100,8 +100,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { // backward auto bw_reverse_seq = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReverseSequence)), input_, input_length_}); - auto bw_max1 = - VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_}); + auto bw_max1 = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_, + std::make_shared(IsParameterNode)}); auto bw_max2 = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimMaximum)), std::make_shared(IsParameterNode), bw_max1}); auto bw_trans = VectorRef({std::make_shared(std::bind(IsOpType, p1, prim::kPrimTranspose)), bw_reverse_seq,