!13016 [MS][LITE]fix GRU fusion for encoder.pb

From: @mengyuanli
Reviewed-by: @hangangqiang,@zhang_xue_tong
Signed-off-by: @hangangqiang
pull/13016/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7207ca65ef

@ -236,6 +236,10 @@ schema::PrimitiveT *DropoutGradPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::DropoutGrad>>(node); auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::DropoutGrad>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
} }
schema::PrimitiveT *GRUPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::GRU>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *EltwisePrimitiveCreator(const AnfNodePtr &node) { schema::PrimitiveT *EltwisePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Eltwise>>(node); auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::Eltwise>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
@ -790,6 +794,7 @@ RegistryMSOps g_gatherPrimitiveCreatorRegistry("Gather", GatherPrimitiveCreator)
RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator); RegistryMSOps g_gatherNdPrimitiveCreatorRegistry("GatherNd", GatherNdPrimitiveCreator);
RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator); RegistryMSOps g_greaterPrimitiveCreatorRegistry("Greater", GreaterPrimitiveCreator);
RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator); RegistryMSOps g_greaterEqualPrimitiveCreatorRegistry("GreaterEqual", GreaterEqualPrimitiveCreator);
RegistryMSOps g_gRUPrimitiveCreatorRegistry("GRU", GRUPrimitiveCreator);
RegistryMSOps g_hashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator); RegistryMSOps g_hashtableLookupPrimitiveCreatorRegistry("HashtableLookup", HashtableLookupPrimitiveCreator);
RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator); RegistryMSOps g_instanceNormPrimitiveCreatorRegistry("InstanceNorm", InstanceNormPrimitiveCreator);
RegistryMSOps g_invertPermutationPrimitiveCreatorRegistry("InvertPermutation", InvertPermutationPrimitiveCreator); RegistryMSOps g_invertPermutationPrimitiveCreatorRegistry("InvertPermutation", InvertPermutationPrimitiveCreator);

@ -70,15 +70,15 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const { const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
// forward // forward
auto fw_max1 = auto fw_reduce = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)),
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_}); input_length_, std::make_shared<CondVar>(IsParameterNode)});
auto fw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)), auto fw_max = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
std::make_shared<CondVar>(IsParameterNode), fw_max1}); std::make_shared<CondVar>(IsParameterNode), fw_reduce});
auto fw_shape = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape)), transpose_input_}); auto fw_shape = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimShape)), transpose_input_});
auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice)), fw_shape, auto fw_stride = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice)), fw_shape,
std::make_shared<SeqVar>()}); std::make_shared<SeqVar>()});
auto fw_min = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max2}); auto fw_min = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMinimum)), fw_stride, fw_max});
auto fw_reserve = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve)), auto fw_reserve = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTensorListReserve)),
std::make_shared<CondVar>(IsParameterNode), fw_stride}); std::make_shared<CondVar>(IsParameterNode), fw_stride});
@ -100,8 +100,8 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
// backward // backward
auto bw_reverse_seq = auto bw_reverse_seq =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReverseSequence)), input_, input_length_}); VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReverseSequence)), input_, input_length_});
auto bw_max1 = auto bw_max1 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_,
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion)), input_length_}); std::make_shared<CondVar>(IsParameterNode)});
auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)), auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMaximum)),
std::make_shared<CondVar>(IsParameterNode), bw_max1}); std::make_shared<CondVar>(IsParameterNode), bw_max1});
auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), bw_reverse_seq, auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose)), bw_reverse_seq,

Loading…
Cancel
Save