|
|
@ -25,9 +25,8 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
from ...common.api import ms_function, _pynative_exec, _wrap_func
|
|
|
|
from ...common.api import ms_function, _pynative_exec, _wrap_func
|
|
|
|
from .. import functional as F
|
|
|
|
from .. import functional as F
|
|
|
|
from ...common.parameter import Parameter
|
|
|
|
|
|
|
|
from ...common.tensor import Tensor
|
|
|
|
from ...common.tensor import Tensor
|
|
|
|
|
|
|
|
from .. import signature as sig
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
|
|
|
|
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
|
|
|
|
|
|
|
|
|
|
|
@ -348,6 +347,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
name (str): Operator name.
|
|
|
|
name (str): Operator name.
|
|
|
|
|
|
|
|
read_value (bool): If the registered function not need to set value on Parameter,
|
|
|
|
|
|
|
|
and all inputs will pass by value. Set `read_value` to True. Default: False.
|
|
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
Raises:
|
|
|
|
ValueError: Cannot find matching fn for the given args.
|
|
|
|
ValueError: Cannot find matching fn for the given args.
|
|
|
@ -358,16 +359,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
|
|
|
|
>>> add = MultitypeFuncGraph('add')
|
|
|
|
>>> add = MultitypeFuncGraph('add')
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, name):
|
|
|
|
def __init__(self, name, read_value=False):
|
|
|
|
MultitypeFuncGraph_.__init__(self, name)
|
|
|
|
MultitypeFuncGraph_.__init__(self, name)
|
|
|
|
self.entries = list()
|
|
|
|
self.entries = list()
|
|
|
|
|
|
|
|
if read_value:
|
|
|
|
|
|
|
|
self.set_signatures((
|
|
|
|
|
|
|
|
sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, *args):
|
|
|
|
def __call__(self, *args):
|
|
|
|
def unwrap(arg):
|
|
|
|
types = tuple(map(mstype.get_py_obj_dtype, args))
|
|
|
|
if isinstance(arg, Parameter):
|
|
|
|
|
|
|
|
return arg.data
|
|
|
|
|
|
|
|
return arg
|
|
|
|
|
|
|
|
types = tuple(map(lambda arg: mstype.get_py_obj_dtype(unwrap(arg)), args))
|
|
|
|
|
|
|
|
for sigs, fn in self.entries:
|
|
|
|
for sigs, fn in self.entries:
|
|
|
|
if len(sigs) != len(types):
|
|
|
|
if len(sigs) != len(types):
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|