|
|
|
@ -18,6 +18,8 @@
|
|
|
|
|
import inspect
|
|
|
|
|
import copy
|
|
|
|
|
from mindspore.common.api import _wrap_func
|
|
|
|
|
from mindspore.common import Parameter
|
|
|
|
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
|
|
|
from .._c_expression import Primitive_, real_run_op, prim_type
|
|
|
|
|
from .._c_expression import signature_rw as sig_rw
|
|
|
|
|
from .._c_expression import signature_kind as sig_kind
|
|
|
|
@ -49,6 +51,7 @@ class Primitive(Primitive_):
|
|
|
|
|
self.name = name
|
|
|
|
|
self.attrs = {}
|
|
|
|
|
self.init_attrs = {"name": name}
|
|
|
|
|
self._update_parameter = False
|
|
|
|
|
Primitive_.__init__(self, name, self)
|
|
|
|
|
if hasattr(self.__class__, '__mindspore_signature__'):
|
|
|
|
|
sig = self._fill_signature(self.__class__.__mindspore_signature__)
|
|
|
|
@ -189,6 +192,11 @@ class Primitive(Primitive_):
|
|
|
|
|
# for checking output number with kernel implementation
|
|
|
|
|
self.add_prim_attr("output_names", outputs)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def update_parameter(self):
|
|
|
|
|
""" Whether the primitive will update the value of parameter."""
|
|
|
|
|
return self._update_parameter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PrimitiveWithInfer(Primitive):
|
|
|
|
|
"""
|
|
|
|
@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|
|
|
|
@_wrap_func
|
|
|
|
|
def _run_op(obj, op_name, args):
|
|
|
|
|
"""Single op execution function supported by ge in PyNative mode."""
|
|
|
|
|
output = real_run_op(obj, op_name, args)
|
|
|
|
|
cast = tensor_operator_registry.get("cast")
|
|
|
|
|
if op_name == "Cast" or obj.update_parameter:
|
|
|
|
|
cast_args = args
|
|
|
|
|
else:
|
|
|
|
|
cast_args = list()
|
|
|
|
|
for arg in args:
|
|
|
|
|
if isinstance(arg, Parameter):
|
|
|
|
|
if arg.cast_type:
|
|
|
|
|
cast_args.append(cast(arg, arg.cast_type))
|
|
|
|
|
else:
|
|
|
|
|
cast_args.append(arg)
|
|
|
|
|
else:
|
|
|
|
|
cast_args.append(arg)
|
|
|
|
|
output = real_run_op(obj, op_name, tuple(cast_args))
|
|
|
|
|
if not output:
|
|
|
|
|
raise RuntimeError("Pynative run op %s failed!" % op_name)
|
|
|
|
|
if len(output) == 1:
|
|
|
|
|