!13036 [lite] fix training bug

From: @xu_anyue
Reviewed-by: @hangangqiang,@jpc_chenjianping
Signed-off-by: @hangangqiang
pull/13036/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 7ba4f7a1dc

@ -608,6 +608,10 @@ schema::PrimitiveT *SpaceToDepthPrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SpaceToDepth>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
}
schema::PrimitiveT *SparseSoftmaxCrossEntropyPrimitiveCreator(const AnfNodePtr &node) {
  // Extract the SparseSoftmaxCrossEntropy op held by this value node; if the
  // node does not carry one, report failure with nullptr, otherwise hand the
  // raw op pointer to MSOp2SchemaOp to build the schema primitive.
  auto op = GetValueNode<std::shared_ptr<mindspore::ops::SparseSoftmaxCrossEntropy>>(node);
  if (op == nullptr) {
    return nullptr;
  }
  return ops::MSOp2SchemaOp(op.get());
}
schema::PrimitiveT *SparseToDensePrimitiveCreator(const AnfNodePtr &node) {
auto ms_primc = GetValueNode<std::shared_ptr<mindspore::ops::SparseToDense>>(node);
return ms_primc != nullptr ? ops::MSOp2SchemaOp(ms_primc.get()) : nullptr;
@ -876,6 +880,8 @@ RegistryMSOps g_softmaxCrossEntropyWithLogitsPrimitiveCreatorRegistry("SoftmaxCr
// Static registrations: map each front-end op type name to the creator that
// converts its ANF value node into a schema::PrimitiveT for the lite converter.
RegistryMSOps g_spaceToBatchPrimitiveCreatorRegistry("SpaceToBatch", SpaceToBatchPrimitiveCreator);
RegistryMSOps g_spaceToBatchNDPrimitiveCreatorRegistry("SpaceToBatchND", SpaceToBatchNDPrimitiveCreator);
RegistryMSOps g_spaceToDepthPrimitiveCreatorRegistry("SpaceToDepth", SpaceToDepthPrimitiveCreator);
// NOTE(review): registered under "SparseSoftmaxCrossEntropyWithLogits" while the
// creator is named SparseSoftmaxCrossEntropy — presumably the key matches the
// exported front-end op name; confirm against the mindspore::ops definition.
RegistryMSOps g_sparseSoftmaxCrossEntropyPrimitiveCreatorRegistry("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyPrimitiveCreator);
RegistryMSOps g_sparseToDensePrimitiveCreatorRegistry("SparseToDense", SparseToDensePrimitiveCreator);
RegistryMSOps g_splitPrimitiveCreatorRegistry("Split", SplitPrimitiveCreator);
RegistryMSOps g_sqrtPrimitiveCreatorRegistry("Sqrt", SqrtPrimitiveCreator);

@ -770,6 +770,9 @@ int AnfExporter::ConvertInputValueNode(const std::shared_ptr<AnfNode> &input_ano
} else if (value->isa<FuncGraph>()) {
MS_LOG(INFO) << "op name:" << input_anode->fullname_with_scope() << " input is func_graph";
return RET_OK;
} else if (value->isa<Monad>()) {
MS_LOG(INFO) << "value is a monad.";
return RET_OK;
} else {
MS_LOG(ERROR) << "Not support value type , need add support.";
return RET_ERROR;

@ -143,8 +143,8 @@ int AnfTransform::AddConvertPass(const std::shared_ptr<opt::GraphOptimizer> &opt
int AnfTransform::AddConstFoldPass(const std::shared_ptr<opt::GraphOptimizer> &optimizer,
const converter::Flags *config) {
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
if (!config->trainModel) {
const_fold_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>());
if (!config->trainModel) {
auto inne_context_ptr = std::make_shared<lite::InnerContext>();
inne_context_ptr->Init();
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));

@ -98,10 +98,6 @@ bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
if (CheckPrimitiveType(node, prim::kPrimUpdateState)) {
status = ReplaceOp(node, manager);
}
if (CheckPrimitiveType(node, prim::kPrimDepend) ||
CheckPrimitiveType(node, prim::kPrimControlDepend)) { // ControlDepend delete next version.
status = ReplaceOp(node, manager);
}
if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
status = ReplaceTupleGetItem(node, manager);
}

Loading…
Cancel
Save