diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index 9251798fd1..17426ad4e0 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -66,23 +66,14 @@ void ValidateOperation(const AnfNodePtr &node) { MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name(); } -void ValidateAbstract(const AnfNodePtr &node) { - if (node == nullptr) { - MS_LOG(DEBUG) << "Node to validate is invalid"; - return; - } +bool CheckAbstractScalar(const AnfNodePtr &node) { AbstractBasePtr ptrBase = node->abstract(); - if (ptrBase == nullptr) { - MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString(); - return; - } - if (ptrBase->isa() || ptrBase->isa()) { - // Validate a type. - MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); - } if (ptrBase->isa()) { TypePtr ptrType = ptrBase->GetTypeTrack(); MS_EXCEPTION_IF_NULL(ptrType); + if (ptrType->isa()) { + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() << " for node=" << node->DebugString(); + } if (ptrType->isa() || ptrType->isa()) { // only send string in external if (!IsValueNode(node)) { @@ -91,26 +82,39 @@ void ValidateAbstract(const AnfNodePtr &node) { << " for node=" << node->DebugString(); } } + return true; + } + return false; +} + +void ValidateAbstract(const AnfNodePtr &node) { + if (node == nullptr) { + MS_LOG(DEBUG) << "Node to validate is invalid"; return; } - if (ptrBase->isa()) { - // NOTICE: validate dead code? - MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString(); + AbstractBasePtr ptrBase = node->abstract(); + if (ptrBase == nullptr) { + MS_LOG(DEBUG) << "Abstract is null in node: " << node->DebugString(); return; } - - if (ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || - ptrBase->isa() || ptrBase->isa() || ptrBase->isa()) { + if (ptrBase->isa() || ptrBase->isa()) { + // Validate a type. + MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString() << " for node=" << node->DebugString(); + } + if (CheckAbstractScalar(node)) { return; } - - if (ptrBase->isa()) { + if (ptrBase->isa()) { + // NOTICE: validate dead code? + MS_LOG(DEBUG) << "AbstractError in the graph: " << ptrBase->ToString(); return; } - - // UMonad or IOMonad - if (ptrBase->isa()) { + bool checkAbstractIslegal = + ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa() || ptrBase->isa() || ptrBase->isa() || + ptrBase->isa() || ptrBase->isa(); + if (checkAbstractIslegal) { return; }