From 890857da015b02e687e024282c93969ec73f4cd7 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Tue, 9 Mar 2021 16:04:24 +0800 Subject: [PATCH] fix train bug --- mindspore/lite/src/ops/ops_utils.cc | 6 ++++++ mindspore/lite/tools/anf_exporter/anf_exporter.cc | 3 +++ mindspore/lite/tools/converter/anf_transform.cc | 2 +- .../lite/tools/optimizer/graph/redundant_op_remove_pass.cc | 4 ---- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/src/ops/ops_utils.cc b/mindspore/lite/src/ops/ops_utils.cc index 1e358e0f5a..b7e1cbbffd 100644 --- a/mindspore/lite/src/ops/ops_utils.cc +++ b/mindspore/lite/src/ops/ops_utils.cc @@ -604,6 +604,10 @@ schema::PrimitiveT *SpaceToDepthPrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; } +schema::PrimitiveT *SparseSoftmaxCrossEntropyPrimitiveCreator(const AnfNodePtr &node) { + auto ms_primc = GetValueNode>(node); + return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; +} schema::PrimitiveT *SparseToDensePrimitiveCreator(const AnfNodePtr &node) { auto ms_primc = GetValueNode>(node); return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr; @@ -871,6 +875,8 @@ RegistryMSOps g_softmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry("SoftmaxCr RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator); RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator); RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator); +RegistryMSOps g_sparseSoftmaxCrossEntropyPrimitiveCreatorRegistry("SparseSoftmaxCrossEntropyWithLogits", + SparseSoftmaxCrossEntropyPrimitiveCreator); RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator); RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator); RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator); diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index d617ad3e8c..13301cadef 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -770,6 +770,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr &input_ano } else if (value->isa()) { MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph"; return RET_OK; + } else if (value->isa()) { + MS_LOG(INFO) << "value is a monad."; + return RET_OK; } else { MS_LOG(ERROR) << "Not support value type , need add support."; return RET_ERROR; diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc index d373173958..87337c3697 100644 --- a/mindspore/lite/tools/converter/anf_transform.cc +++ b/mindspore/lite/tools/converter/anf_transform.cc @@ -143,8 +143,8 @@ int AnfTransform::AddConvertPass(const std::shared_ptr &opt int AnfTransform::AddConstFoldPass(const std::shared_ptr &optimizer, const converter::Flags *config) { auto const_fold_pm = std::make_shared("const fold fusion pass manager", false); + const_fold_pm->AddPass(std::make_shared()); if (!config->trainModel) { - const_fold_pm->AddPass(std::make_shared()); auto inne_context_ptr = std::make_shared(); inne_context_ptr->Init(); const_fold_pm->AddPass(std::make_shared(inne_context_ptr)); diff --git a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc index 6057a1c92c..046998ba63 100644 --- a/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/redundant_op_remove_pass.cc @@ -98,10 +98,6 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) { if (CheckPrimitiveType(node, prim::kPrimUpdateState)) { status = ReplaceOp(node, manager); } - if (CheckPrimitiveType(node, prim::kPrimDepend) || - CheckPrimitiveType(node, prim::kPrimControlDepend)) { // ControlDepend delete next version. - status = ReplaceOp(node, manager); - } if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) { status = ReplaceTupleGetItem(node, manager); }