|
|
|
@ -17,18 +17,21 @@
|
|
|
|
|
"""The module of parser python object, called by c++."""
|
|
|
|
|
|
|
|
|
|
import ast
|
|
|
|
|
import types
|
|
|
|
|
import inspect
|
|
|
|
|
import hashlib
|
|
|
|
|
from textwrap import dedent
|
|
|
|
|
import inspect
|
|
|
|
|
import types
|
|
|
|
|
from dataclasses import is_dataclass
|
|
|
|
|
from textwrap import dedent
|
|
|
|
|
|
|
|
|
|
import asttokens
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
|
|
|
|
|
from mindspore import Tensor as MsTensor
|
|
|
|
|
from mindspore import context
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from mindspore import nn
|
|
|
|
|
from mindspore import ops
|
|
|
|
|
from mindspore.common.dtype import pytype_to_dtype
|
|
|
|
|
from mindspore.common.api import _MindSporeFunction
|
|
|
|
|
from mindspore.common.dtype import pytype_to_dtype
|
|
|
|
|
from .namespace import CellNamespace, ClosureNamespace, ClassMemberNamespace
|
|
|
|
|
from .resources import parse_object_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
|
|
|
|
|
|
|
|
@ -100,6 +103,8 @@ def get_parse_method_of_class(obj, parse_method=None):
|
|
|
|
|
else:
|
|
|
|
|
if isinstance(obj, nn.Cell):
|
|
|
|
|
if obj.enable_hook:
|
|
|
|
|
if context.get_context("mode") == context.GRAPH_MODE:
|
|
|
|
|
raise ValueError("The graph mode does not support hook function.")
|
|
|
|
|
method_name = "_hook_construct"
|
|
|
|
|
else:
|
|
|
|
|
method_name = "construct"
|
|
|
|
|