From 879a519136eb3c596341c8fa47ff8de0b05cdca6 Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Tue, 1 Sep 2020 23:36:13 +0800 Subject: [PATCH] update signature --- mindspore/_extends/builtin_operations.py | 20 -- .../operator/composite/do_signature.cc | 6 +- .../pipeline/jit/static_analysis/prim.cc | 3 - .../ccsrc/pybind_api/ir/func_graph_py.cc | 2 +- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 17 +- mindspore/ccsrc/pybind_api/ir/primitive_py.h | 4 +- mindspore/ccsrc/pybind_api/ir/signature_py.cc | 14 + mindspore/common/tensor.py | 21 ++ mindspore/ops/__init__.py | 9 +- mindspore/ops/composite/base.py | 16 +- .../ops/composite/multitype_ops/add_impl.py | 2 +- .../ops/composite/multitype_ops/div_impl.py | 2 +- .../ops/composite/multitype_ops/equal_impl.py | 2 +- .../composite/multitype_ops/floordiv_impl.py | 2 +- .../composite/multitype_ops/getitem_impl.py | 2 +- .../multitype_ops/greater_equal_impl.py | 2 +- .../composite/multitype_ops/greater_impl.py | 2 +- .../ops/composite/multitype_ops/in_impl.py | 2 +- .../multitype_ops/less_equal_impl.py | 2 +- .../ops/composite/multitype_ops/less_impl.py | 2 +- .../composite/multitype_ops/logic_not_impl.py | 2 +- .../multitype_ops/logical_and_impl.py | 2 +- .../multitype_ops/logical_or_impl.py | 2 +- .../ops/composite/multitype_ops/mod_impl.py | 2 +- .../ops/composite/multitype_ops/mul_impl.py | 2 +- .../composite/multitype_ops/negative_impl.py | 2 +- .../composite/multitype_ops/not_equal_impl.py | 2 +- .../composite/multitype_ops/ones_like_impl.py | 2 +- .../ops/composite/multitype_ops/pow_impl.py | 2 +- .../ops/composite/multitype_ops/sub_impl.py | 2 +- .../ops/composite/multitype_ops/uadd_impl.py | 2 +- .../multitype_ops/zeros_like_impl.py | 2 +- mindspore/ops/functional.py | 4 +- mindspore/ops/operations/_grad_ops.py | 5 +- mindspore/ops/operations/array_ops.py | 13 +- mindspore/ops/operations/math_ops.py | 18 +- mindspore/ops/operations/nn_ops.py | 264 ++++++++---------- mindspore/ops/operations/other_ops.py | 8 +- 
mindspore/ops/primitive.py | 23 +- mindspore/ops/signature.py | 54 ++++ tests/ut/python/model/test_mix_precision.py | 4 +- tests/ut/python/ops/test_array_ops.py | 10 +- 42 files changed, 289 insertions(+), 270 deletions(-) create mode 100644 mindspore/ops/signature.py diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py index 1eade2d86d..0fd95eb13c 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -20,7 +20,6 @@ from mindspore.common.tensor import Tensor import mindspore.common.dtype as mstype from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype - def scalar_add(x, y): """Implement `scalar_add`.""" return x + y @@ -117,25 +116,6 @@ def bool_or(x, y): return x or y -def vm_compare(*args): - """Implement `vm_compare` for tensor.""" - obj_str = args[-1] - if obj_str == "shape": - fn = getattr(args[0].asnumpy(), obj_str) - return fn - if len(args) == 2: - fn = getattr(args[0].asnumpy(), obj_str) - return Tensor(fn()) - if isinstance(args[0], Tensor): - fn = getattr(args[0].asnumpy(), obj_str) - y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1] - else: - obj_str = "__r" + obj_str[2:] - fn = getattr(args[1].asnumpy(), obj_str) - y = args[0] - return Tensor(np.array(fn(y))) - - def make_list(*xs): """Implement `make_list`.""" return list(xs) diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc index 706aafe418..4ca8eaa152 100644 --- a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc @@ -262,6 +262,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func std::set write_indices; std::vector input_types; op_inputs.push_back(NewValueNode(function)); + auto cast_type = parse::GetMixedPrecisionTargetType(func_graph); // Assume, the write input of op is always the 
first input. We check if any write op, // and add cast op on other inputs to keep the same type with assigned parameter. for (size_t i = 0; i < args_spec_list.size(); ++i) { @@ -280,7 +281,6 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func TypePtr type = args_spec_list[i]->BuildType(); if (type && type->isa()) { - auto cast_type = parse::GetMixedPrecisionTargetType(func_graph); if (sig == SignatureEnumRW::kRWRead) { auto source_tensor_type = type->cast(); if (source_tensor_type != nullptr) { @@ -300,8 +300,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but " << type->ToString(); } - MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type " - << args_spec_list[i]->ToString(); + MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs " + << args_spec_list[i]->ToString() << " type " << type->ToString(); input_types.push_back(type); op_inputs.push_back(param); } diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 252c4c09d7..2c4b4921ab 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -305,9 +305,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { dic[ATTR_SHAPE] = shape; dic[ATTR_DTYPE] = arg_slice->BuildType(); dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue()); - } else if (abs_base->isa()) { - auto value = abs_base->cast()->ref(); - dic = ConvertAbstractToPython(value); } else if (abs_base->isa()) { dic[ATTR_SHAPE] = py::none(); dic[ATTR_DTYPE] = py::ellipsis(); diff --git a/mindspore/ccsrc/pybind_api/ir/func_graph_py.cc b/mindspore/ccsrc/pybind_api/ir/func_graph_py.cc index 768befffe3..e116647c86 100644 --- 
a/mindspore/ccsrc/pybind_api/ir/func_graph_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/func_graph_py.cc @@ -23,7 +23,7 @@ namespace mindspore { REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) { // Define python "MetaFuncGraph_" class (void)py::class_>(*m, "MetaFuncGraph_") - .def(py::init()); + .def("set_signatures", &MetaFuncGraph::set_signatures, "Set primitive inputs signature."); // Define python "FuncGraph" class (void)py::class_(*m, "FuncGraph") .def(py::init()) diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 4c0022bf9e..d8519191b7 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -48,22 +48,9 @@ void SyncData(const py::object &arg) { } } // namespace std::map PrimitivePy::hook_grad_; -static ValuePtr PyArgToValue(const py::object &arg) { - if (py::isinstance(arg) && - py::cast(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { - return nullptr; - } - return parse::data_converter::PyDataToValue(arg); -} -void PrimitivePy::set_signatures( - std::vector> signatures) { - signatures_.clear(); - for (auto &signature : signatures) { - auto [name, rw, kind, arg_default, dtype] = signature; - auto default_value = PyArgToValue(arg_default); - signatures_.emplace_back(name, rw, kind, default_value, dtype); - } +void PrimitivePy::set_signatures(const std::vector &signatures) { + signatures_ = signatures; set_has_signature(true); } diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.h b/mindspore/ccsrc/pybind_api/ir/primitive_py.h index 3a19f0f43d..02ac810361 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.h +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.h @@ -42,9 +42,7 @@ class PrimitivePy : public Primitive { MS_DECLARE_PARENT(PrimitivePy, Primitive); py::function GetBpropFunction(); - void set_signatures( - std::vector> - signatures); + void set_signatures(const std::vector &signatures); const std::vector 
&signatures() const { return signatures_; } diff --git a/mindspore/ccsrc/pybind_api/ir/signature_py.cc b/mindspore/ccsrc/pybind_api/ir/signature_py.cc index 4dbf070462..1f92b7cdff 100644 --- a/mindspore/ccsrc/pybind_api/ir/signature_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/signature_py.cc @@ -17,12 +17,26 @@ #include "ir/signature.h" #include "pybind11/operators.h" #include "pybind_api/api_register.h" +#include "pipeline/jit/parse/data_converter.h" namespace py = pybind11; namespace mindspore { +static ValuePtr PyArgToValue(const py::object &arg) { + if (py::isinstance(arg) && + py::cast(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { + return nullptr; + } + return parse::data_converter::PyDataToValue(arg); +} // Bind SignatureEnumRW as a python class. REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { + (void)py::class_(*m, "Signature") + .def(py::init([](std::string name, SignatureEnumRW rw, SignatureEnumKind kind, + py::object arg_default, SignatureEnumDType dtype) { + auto default_value = PyArgToValue(arg_default); + return Signature(name, rw, kind, default_value, dtype); + })); (void)py::enum_(*m, "signature_rw", py::arithmetic()) .value("RW_READ", SignatureEnumRW::kRWRead) .value("RW_WRITE", SignatureEnumRW::kRWWrite) diff --git a/mindspore/common/tensor.py b/mindspore/common/tensor.py index 5bacb5741b..e03e0b5b02 100644 --- a/mindspore/common/tensor.py +++ b/mindspore/common/tensor.py @@ -393,3 +393,24 @@ class SparseTensor: @property def dense_shape(self): return self.__dense_shape + + +def _vm_compare(*args): + """Implement `vm_compare` for tensor.""" + obj_str = args[-1] + if obj_str == "shape": + fn = getattr(args[0].asnumpy(), obj_str) + return fn + if len(args) == 2: + fn = getattr(args[0].asnumpy(), obj_str) + return Tensor(fn()) + if isinstance(args[0], Tensor): + fn = getattr(args[0].asnumpy(), obj_str) + y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1] + else: + obj_str = "__r" + obj_str[2:] + fn = 
getattr(args[1].asnumpy(), obj_str) + y = args[0] + return Tensor(np.array(fn(y))) + +tensor_operator_registry.register('vm_compare', _vm_compare) diff --git a/mindspore/ops/__init__.py b/mindspore/ops/__init__.py index aa4c5662e3..1a21f90a07 100644 --- a/mindspore/ops/__init__.py +++ b/mindspore/ops/__init__.py @@ -34,14 +34,17 @@ from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType from .primitive import constexpr -from .._c_expression import signature_rw, signature_kind +from . import composite, operations, functional +from . import signature __primitive__ = [ - "prim_attr_register", "Primitive", "PrimitiveWithInfer", - "signature_rw", "signature_kind" + "prim_attr_register", "Primitive", "PrimitiveWithInfer", "signature" ] __all__ = ["get_vm_impl_fn", "vm_impl_registry", "op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType", "constexpr"] __all__.extend(__primitive__) +__all__.extend(composite.__all__) +__all__.extend(operations.__all__) +__all__.extend(functional.__all__) diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py index ec14b0f7d0..99c37c6988 100644 --- a/mindspore/ops/composite/base.py +++ b/mindspore/ops/composite/base.py @@ -25,9 +25,8 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult from ...common import dtype as mstype from ...common.api import ms_function, _pynative_exec, _wrap_func from .. import functional as F -from ...common.parameter import Parameter from ...common.tensor import Tensor - +from .. import signature as sig __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] @@ -348,6 +347,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): Args: name (str): Operator name. 
+ read_value (bool): If the registered function not need to set value on Parameter, + and all inputs will pass by value. Set `read_value` to True. Default: False. Raises: ValueError: Cannot find matching fn for the given args. @@ -358,16 +359,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_): >>> add = MultitypeFuncGraph('add') """ - def __init__(self, name): + def __init__(self, name, read_value=False): MultitypeFuncGraph_.__init__(self, name) self.entries = list() + if read_value: + self.set_signatures(( + sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),)) def __call__(self, *args): - def unwrap(arg): - if isinstance(arg, Parameter): - return arg.data - return arg - types = tuple(map(lambda arg: mstype.get_py_obj_dtype(unwrap(arg)), args)) + types = tuple(map(mstype.get_py_obj_dtype, args)) for sigs, fn in self.entries: if len(sigs) != len(types): continue diff --git a/mindspore/ops/composite/multitype_ops/add_impl.py b/mindspore/ops/composite/multitype_ops/add_impl.py index 2ad81bfc93..6a6ca9fe4d 100644 --- a/mindspore/ops/composite/multitype_ops/add_impl.py +++ b/mindspore/ops/composite/multitype_ops/add_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -add = base.MultitypeFuncGraph('add') +add = base.MultitypeFuncGraph('add', True) """`add` is a metafuncgraph object which will add two objects according to input type using ".register" decorator.""" diff --git a/mindspore/ops/composite/multitype_ops/div_impl.py b/mindspore/ops/composite/multitype_ops/div_impl.py index 85a4e035c0..b1a7f040f1 100644 --- a/mindspore/ops/composite/multitype_ops/div_impl.py +++ b/mindspore/ops/composite/multitype_ops/div_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... 
import functional as F -div = base.MultitypeFuncGraph("div") +div = base.MultitypeFuncGraph("div", True) """ div is a metafuncgraph object which will div two objects according to input type using ".register" decorator diff --git a/mindspore/ops/composite/multitype_ops/equal_impl.py b/mindspore/ops/composite/multitype_ops/equal_impl.py index 97e647bc7c..e39f5af332 100644 --- a/mindspore/ops/composite/multitype_ops/equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/equal_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -equal = base.MultitypeFuncGraph("equal") +equal = base.MultitypeFuncGraph("equal", True) """ equal is a metafuncgraph object which will determine if two objects are equal according to input type using ".register" decorator diff --git a/mindspore/ops/composite/multitype_ops/floordiv_impl.py b/mindspore/ops/composite/multitype_ops/floordiv_impl.py index 8e9e941309..238867f457 100644 --- a/mindspore/ops/composite/multitype_ops/floordiv_impl.py +++ b/mindspore/ops/composite/multitype_ops/floordiv_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -floordiv = base.MultitypeFuncGraph("floordiv") +floordiv = base.MultitypeFuncGraph("floordiv", True) """ `floordiv` is a metafuncgraph object which will compute the floordiv of two objects using ".register" decorator. diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index 48e4c71ca6..24f2f5fbff 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -19,7 +19,7 @@ from .. import base from ... import functional as F -getitem = base.MultitypeFuncGraph('getitem') +getitem = base.MultitypeFuncGraph('getitem', True) """ getitem is a metafuncgraph object which will get item from an object according to input type using ".register" decorator. 
diff --git a/mindspore/ops/composite/multitype_ops/greater_equal_impl.py b/mindspore/ops/composite/multitype_ops/greater_equal_impl.py index 93f1acbc54..4d2bd78e30 100644 --- a/mindspore/ops/composite/multitype_ops/greater_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/greater_equal_impl.py @@ -19,7 +19,7 @@ from mindspore.ops import functional as F # greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type # using ".register" decorator -greater_equal = base.MultitypeFuncGraph("greater_equal") +greater_equal = base.MultitypeFuncGraph("greater_equal", True) @greater_equal.register("Number", "Number") diff --git a/mindspore/ops/composite/multitype_ops/greater_impl.py b/mindspore/ops/composite/multitype_ops/greater_impl.py index 2f3a2dbb83..35a8daa43a 100644 --- a/mindspore/ops/composite/multitype_ops/greater_impl.py +++ b/mindspore/ops/composite/multitype_ops/greater_impl.py @@ -19,7 +19,7 @@ from mindspore.ops import functional as F # greater is a metafuncgraph object which will determine if two objects are greater according to input type # using ".register" decorator -greater = base.MultitypeFuncGraph("greater") +greater = base.MultitypeFuncGraph("greater", True) @greater.register("Number", "Number") diff --git a/mindspore/ops/composite/multitype_ops/in_impl.py b/mindspore/ops/composite/multitype_ops/in_impl.py index 26f7bce437..afa42a1239 100644 --- a/mindspore/ops/composite/multitype_ops/in_impl.py +++ b/mindspore/ops/composite/multitype_ops/in_impl.py @@ -19,7 +19,7 @@ from . import _constexpr_utils as const_utils from ... 
import functional as F from ...composite import base -in_ = base.MultitypeFuncGraph("in") +in_ = base.MultitypeFuncGraph("in", True) """ in_ is a metafuncgraph object which will determine if a in b using ".register" decorator diff --git a/mindspore/ops/composite/multitype_ops/less_equal_impl.py b/mindspore/ops/composite/multitype_ops/less_equal_impl.py index 5927c4b349..13f229b8e7 100644 --- a/mindspore/ops/composite/multitype_ops/less_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/less_equal_impl.py @@ -19,7 +19,7 @@ from mindspore.ops import functional as F # less_equal is a metagraph object which will determine if two objects are less_equal according to input type # using ".register" decorator -less_equal = base.MultitypeFuncGraph("less_equal") +less_equal = base.MultitypeFuncGraph("less_equal", True) @less_equal.register("Number", "Number") diff --git a/mindspore/ops/composite/multitype_ops/less_impl.py b/mindspore/ops/composite/multitype_ops/less_impl.py index 6e50e54c82..0fc9f72417 100644 --- a/mindspore/ops/composite/multitype_ops/less_impl.py +++ b/mindspore/ops/composite/multitype_ops/less_impl.py @@ -19,7 +19,7 @@ from mindspore.ops import functional as F # less is a metafuncgraph object which will determine if two objects are less according to input type # using ".register" decorator -less = base.MultitypeFuncGraph("less") +less = base.MultitypeFuncGraph("less", True) @less.register("Number", "Number") diff --git a/mindspore/ops/composite/multitype_ops/logic_not_impl.py b/mindspore/ops/composite/multitype_ops/logic_not_impl.py index 73219afec1..842e19242c 100644 --- a/mindspore/ops/composite/multitype_ops/logic_not_impl.py +++ b/mindspore/ops/composite/multitype_ops/logic_not_impl.py @@ -19,7 +19,7 @@ from mindspore.ops import functional as F # logical_not is a metagraph object which will generate function according to input type # using ".register" decorator -logical_not = base.MultitypeFuncGraph("logical_not") +logical_not = 
base.MultitypeFuncGraph("logical_not", True) @logical_not.register("Number") diff --git a/mindspore/ops/composite/multitype_ops/logical_and_impl.py b/mindspore/ops/composite/multitype_ops/logical_and_impl.py index 79001f43e8..d287a88cf2 100644 --- a/mindspore/ops/composite/multitype_ops/logical_and_impl.py +++ b/mindspore/ops/composite/multitype_ops/logical_and_impl.py @@ -19,7 +19,7 @@ from mindspore.ops import functional as F # logical_and is a metagraph object which will generate function according to input type # using ".register" decorator -logical_and = base.MultitypeFuncGraph("logical_and") +logical_and = base.MultitypeFuncGraph("logical_and", True) @logical_and.register("Number", "Number") diff --git a/mindspore/ops/composite/multitype_ops/logical_or_impl.py b/mindspore/ops/composite/multitype_ops/logical_or_impl.py index 6d070d5cbf..e6b11fdc9b 100644 --- a/mindspore/ops/composite/multitype_ops/logical_or_impl.py +++ b/mindspore/ops/composite/multitype_ops/logical_or_impl.py @@ -19,7 +19,7 @@ from mindspore.ops import functional as F # logical_or is a metagraph object which will generate function according to input type # using ".register" decorator -logical_or = base.MultitypeFuncGraph("logical_or") +logical_or = base.MultitypeFuncGraph("logical_or", True) @logical_or.register("Number", "Number") diff --git a/mindspore/ops/composite/multitype_ops/mod_impl.py b/mindspore/ops/composite/multitype_ops/mod_impl.py index 4b6a13bbc8..3becc239e0 100644 --- a/mindspore/ops/composite/multitype_ops/mod_impl.py +++ b/mindspore/ops/composite/multitype_ops/mod_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -mod = base.MultitypeFuncGraph("mod") +mod = base.MultitypeFuncGraph("mod", True) """ `mod` is a metafuncgraph object which will compute the mod of two objects using ".register" decorator. 
diff --git a/mindspore/ops/composite/multitype_ops/mul_impl.py b/mindspore/ops/composite/multitype_ops/mul_impl.py index b5535df135..4aeb4aa5f5 100644 --- a/mindspore/ops/composite/multitype_ops/mul_impl.py +++ b/mindspore/ops/composite/multitype_ops/mul_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -mul = base.MultitypeFuncGraph("mul") +mul = base.MultitypeFuncGraph("mul", True) """ `mul` is a metafuncgraph object which will multiply two objects according to input type using ".register" decorator. diff --git a/mindspore/ops/composite/multitype_ops/negative_impl.py b/mindspore/ops/composite/multitype_ops/negative_impl.py index 81fc52e332..78c924c5fa 100644 --- a/mindspore/ops/composite/multitype_ops/negative_impl.py +++ b/mindspore/ops/composite/multitype_ops/negative_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -negative = base.MultitypeFuncGraph("negative") +negative = base.MultitypeFuncGraph("negative", True) """ `negative` is a metafuncgraph object which will give the negative of an object according to its input type using ".register" decorator. diff --git a/mindspore/ops/composite/multitype_ops/not_equal_impl.py b/mindspore/ops/composite/multitype_ops/not_equal_impl.py index 7196f370cb..f5aec4bbaf 100644 --- a/mindspore/ops/composite/multitype_ops/not_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/not_equal_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... 
import functional as F -not_equal = base.MultitypeFuncGraph("not_equal") +not_equal = base.MultitypeFuncGraph("not_equal", True) """ not_equal is a metafuncgraph object which will determine if two objects are not_equal according to input type using ".register" decorator diff --git a/mindspore/ops/composite/multitype_ops/ones_like_impl.py b/mindspore/ops/composite/multitype_ops/ones_like_impl.py index 840571c8b1..dcb8e29430 100644 --- a/mindspore/ops/composite/multitype_ops/ones_like_impl.py +++ b/mindspore/ops/composite/multitype_ops/ones_like_impl.py @@ -22,7 +22,7 @@ from ... import functional as F from ... import operations as P -ones_like_leaf = base.MultitypeFuncGraph('ones_like_leaf') +ones_like_leaf = base.MultitypeFuncGraph('ones_like_leaf', True) """ `ones_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type using ".register" decorator. diff --git a/mindspore/ops/composite/multitype_ops/pow_impl.py b/mindspore/ops/composite/multitype_ops/pow_impl.py index 8d73335c98..76b4c62d61 100644 --- a/mindspore/ops/composite/multitype_ops/pow_impl.py +++ b/mindspore/ops/composite/multitype_ops/pow_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -pow_ = base.MultitypeFuncGraph("pow") +pow_ = base.MultitypeFuncGraph("pow", True) """ `pow` is a metafuncgraph object which will compute the pow of two objects using ".register" decorator. diff --git a/mindspore/ops/composite/multitype_ops/sub_impl.py b/mindspore/ops/composite/multitype_ops/sub_impl.py index 864b8678d4..430a360955 100644 --- a/mindspore/ops/composite/multitype_ops/sub_impl.py +++ b/mindspore/ops/composite/multitype_ops/sub_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -sub = base.MultitypeFuncGraph("sub") +sub = base.MultitypeFuncGraph("sub", True) """ `sub` is a metafuncgraph object which will compute the subtraction of two objects using ".register" decorator. 
diff --git a/mindspore/ops/composite/multitype_ops/uadd_impl.py b/mindspore/ops/composite/multitype_ops/uadd_impl.py index 163120b541..53e4e087a5 100644 --- a/mindspore/ops/composite/multitype_ops/uadd_impl.py +++ b/mindspore/ops/composite/multitype_ops/uadd_impl.py @@ -18,7 +18,7 @@ from mindspore.ops.composite import base # uadd is a metagraph object which will return operation result regarding input # using ".register" decorator -uadd = base.MultitypeFuncGraph("uadd") +uadd = base.MultitypeFuncGraph("uadd", True) @uadd.register("Tensor") @uadd.register("Number") diff --git a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py index 9732d84fdc..7872952aec 100644 --- a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py +++ b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py @@ -19,7 +19,7 @@ from ...composite import base from ... import functional as F -zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf') +zeros_like_leaf = base.MultitypeFuncGraph('zeros_like_leaf', True) """ `zeros_like_leaf` is a metafuncgraph object which will generate a tensor filled with one according to its input type using ".register" decorator. diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index bbabc03733..3b9a167110 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -21,7 +21,6 @@ from mindspore.common._register_for_tensor import tensor_operator_registry from .primitive import Primitive from . 
import operations as P from .operations import _grad_ops -from .._extends import builtin_operations as BP typeof = Primitive('typeof') hastype = Primitive('hastype') @@ -182,5 +181,6 @@ tensor_operator_registry.register('__gt__', tensor_gt) tensor_operator_registry.register('__ge__', tensor_ge) tensor_operator_registry.register('shape', shape) # support GE backend for no compare operators -tensor_operator_registry.register('vm_compare', BP.vm_compare) tensor_operator_registry.register('cast', cast) + +__all__ = [name for name in dir() if name[0] != "_"] diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index e40b08be25..c16ef4872a 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -15,8 +15,7 @@ """Operators for gradients.""" -from ..._c_expression import signature_rw as sig_rw -from ..._c_expression import signature_kind as sig_kind +from .. import signature as sig from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register from ..._checkparam import Validator as validator, Rel from .._utils import get_concat_offset @@ -1500,7 +1499,7 @@ class RefToEmbed(Primitive): >>> return key, self.weight """ __mindspore_signature__ = ( - ('variable', sig_rw.RW_REF, sig_kind.KIND_POSITIONAL_KEYWORD), + sig.make_sig('variable', sig.sig_rw.RW_REF), ) @prim_attr_register diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 8f0a9ba427..fff51ef309 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -28,10 +28,7 @@ import numpy as np from .._utils import get_concat_offset from ..operations.math_ops import _infer_shape_reduce from ..primitive import Primitive, PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op -from ..._c_expression import signature_dtype as sig_dtype -from ..._c_expression import signature_kind as sig_kind -from ..._c_expression import signature_rw as sig_rw 
-from ..._c_expression import typing +from .. import signature as sig from ..._checkparam import Rel from ..._checkparam import Validator as validator from ...common import dtype as mstype @@ -44,9 +41,9 @@ class _ScatterOp(PrimitiveWithInfer): Define Scatter operators """ __mindspore_signature__ = ( - ('x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('updates', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('updates', dtype=sig.sig_dtype.T) ) def _check_scatter_shape(self, x_shape, indices_shape, updates_shape, prim_name): @@ -1396,7 +1393,7 @@ class Tile(PrimitiveWithInfer): validator.check_value_type("shape", multiples_v, [tuple], self.name) for i, multiple in enumerate(multiples_v): validator.check_value_type("multiples[%d]" % i, multiple, [int], self.name) - validator.check_value_type("x[\'dtype\']", x["dtype"], typing.TensorType, self.name) + validator.check_value_type("x[\'dtype\']", x["dtype"], mstype.tensor_type, self.name) len_sub = len(multiples_v) - len(x_shp) multiples_w = None if len_sub == 0: diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index ad2cc1c6fc..babb511385 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -18,9 +18,7 @@ import copy import numpy as np from ... import context -from ..._c_expression import signature_rw as sig_rw -from ..._c_expression import signature_kind as sig_kind -from ..._c_expression import signature_dtype as sig_dtype +from .. 
import signature as sig from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype @@ -68,7 +66,7 @@ class _BinaryOp(PrimitiveWithInfer): Define binary operators. """ - __mindspore_signature__ = (sig_dtype.T, sig_dtype.T) + __mindspore_signature__ = (sig.sig_dtype.T, sig.sig_dtype.T) @prim_attr_register def __init__(self): @@ -186,8 +184,8 @@ class AssignAdd(PrimitiveWithInfer): >>> net(value) """ __mindspore_signature__ = ( - ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('x', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('value', dtype=sig.sig_dtype.T) ) @prim_attr_register @@ -237,8 +235,8 @@ class AssignSub(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('value', dtype=sig.sig_dtype.T) ) @prim_attr_register @@ -264,8 +262,8 @@ class _Reduce(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD), - ('axis', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, ()), + sig.make_sig('input_x'), + sig.make_sig('axis', default=()) ) @prim_attr_register diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 2c94af7847..903940f4b7 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -22,9 +22,7 @@ from functools import reduce import numpy as np from ... 
import context -from ..._c_expression import signature_rw as sig_rw -from ..._c_expression import signature_kind as sig_kind -from ..._c_expression import signature_dtype as sig_dtype +from .. import signature as sig from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype @@ -679,11 +677,11 @@ class FusedBatchNormEx(PrimitiveWithInfer): >>> output = op(input_x, scale, bias, mean, variance) """ __mindspore_signature__ = ( - ('input_x', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), - ('scale', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('bias', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('mean', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('variance', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + sig.make_sig('input_x', dtype=sig.sig_dtype.T2), + sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -1722,13 +1720,11 @@ class ApplyMomentum(PrimitiveWithInfer): Please refer to the usage in nn.ApplyMomentum. 
""" __mindspore_signature__ = ( - ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T1), - ('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2) + sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1), + sig.make_sig('gradient', dtype=sig.sig_dtype.T), + sig.make_sig('momentum', dtype=sig.sig_dtype.T2), ) @prim_attr_register @@ -3146,23 +3142,17 @@ class FusedSparseAdam(PrimitiveWithInfer): >>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices) """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('beta2_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, 
sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('beta1_power', dtype=sig.sig_dtype.T), + sig.make_sig('beta2_power', dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T), + sig.make_sig('beta1', dtype=sig.sig_dtype.T), + sig.make_sig('beta2', dtype=sig.sig_dtype.T), + sig.make_sig('epsilon', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), ) @prim_attr_register @@ -3285,23 +3275,17 @@ class FusedSparseLazyAdam(PrimitiveWithInfer): >>> result = net(beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices) """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('beta2_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('beta2', 
sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('beta1_power', dtype=sig.sig_dtype.T), + sig.make_sig('beta2_power', dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T), + sig.make_sig('beta1', dtype=sig.sig_dtype.T), + sig.make_sig('beta2', dtype=sig.sig_dtype.T), + sig.make_sig('epsilon', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), ) @prim_attr_register @@ -3394,11 +3378,11 @@ class FusedSparseFtrl(PrimitiveWithInfer): >>> output = net(grad, indices) """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + 
sig.make_sig('indices', dtype=sig.sig_dtype.T1), ) @prim_attr_register @@ -3492,13 +3476,13 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): >>> output = net(grad, indices) """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T), + sig.make_sig('l1', dtype=sig.sig_dtype.T), + sig.make_sig('l2', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), ) @prim_attr_register @@ -3754,16 +3738,15 @@ class ApplyAdaMax(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('v', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('beta1_power', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T1), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, 
sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), - ('beta1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), - ('beta2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4), - ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T5), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('v', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('beta1_power', dtype=sig.sig_dtype.T1), + sig.make_sig('lr', dtype=sig.sig_dtype.T2), + sig.make_sig('beta1', dtype=sig.sig_dtype.T3), + sig.make_sig('beta2', dtype=sig.sig_dtype.T4), + sig.make_sig('epsilon', dtype=sig.sig_dtype.T5), + sig.make_sig('grad', dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -3873,14 +3856,13 @@ class ApplyAdadelta(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum_update', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('rho', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), - ('epsilon', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', 
sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum_update', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T1), + sig.make_sig('rho', dtype=sig.sig_dtype.T2), + sig.make_sig('epsilon', dtype=sig.sig_dtype.T3), + sig.make_sig('grad', dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -3971,10 +3953,10 @@ class ApplyAdagrad(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T1), + sig.make_sig('grad', dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -4054,10 +4036,10 @@ class ApplyAdagradV2(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T1), + sig.make_sig('grad', dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -4137,10 +4119,10 @@ class SparseApplyAdagrad(PrimitiveWithInfer): """ 
__mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), ) @prim_attr_register @@ -4224,10 +4206,10 @@ class SparseApplyAdagradV2(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), ) @prim_attr_register @@ -4313,12 +4295,12 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('l1', sig_rw.RW_READ, 
sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), - ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T1), + sig.make_sig('l1', dtype=sig.sig_dtype.T2), + sig.make_sig('l2', dtype=sig.sig_dtype.T3), + sig.make_sig('grad', dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -4418,13 +4400,13 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), - ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T1), + sig.make_sig('l1', dtype=sig.sig_dtype.T2), + sig.make_sig('l2', dtype=sig.sig_dtype.T3), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T4), ) @prim_attr_register @@ -4508,14 +4490,13 @@ class ApplyAddSign(PrimitiveWithInfer): """ 
__mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), - ('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T3), - ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T1), + sig.make_sig('alpha', dtype=sig.sig_dtype.T2), + sig.make_sig('sign_decay', dtype=sig.sig_dtype.T3), + sig.make_sig('beta', dtype=sig.sig_dtype.T4), + sig.make_sig('grad', dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -4618,14 -4599,13 @@ class ApplyPowerSign(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T), - ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ,
sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('m', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('lr', dtype=sig.sig_dtype.T), + sig.make_sig('logbase', dtype=sig.sig_dtype.T), + sig.make_sig('sign_decay', dtype=sig.sig_dtype.T), + sig.make_sig('beta', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -4704,9 +4684,9 @@ class ApplyGradientDescent(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('delta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('alpha', dtype=sig.sig_dtype.T1), + sig.make_sig('delta', dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -4777,11 +4757,11 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('alpha', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('l1', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), - ('l2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T3), - ('delta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('alpha', dtype=sig.sig_dtype.T1), + sig.make_sig('l1', dtype=sig.sig_dtype.T2), + sig.make_sig('l2', dtype=sig.sig_dtype.T3), + sig.make_sig('delta', 
dtype=sig.sig_dtype.T), ) @prim_attr_register @@ -5032,11 +5012,11 @@ class SparseApplyFtrl(PrimitiveWithCheck): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), ) @prim_attr_register @@ -5126,11 +5106,11 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): """ __mindspore_signature__ = ( - ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('accum', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('linear', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1) + sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('indices', dtype=sig.sig_dtype.T1), ) 
@prim_attr_register diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index e0c1e484d4..9d24ff8b44 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -15,9 +15,7 @@ """Other operators.""" import functools -from ..._c_expression import signature_rw as sig_rw -from ..._c_expression import signature_kind as sig_kind -from ..._c_expression import signature_dtype as sig_dtype +from .. import signature as sig from ..._checkparam import Validator as validator, Rel from ...common import dtype as mstype from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register @@ -53,8 +51,8 @@ class Assign(Primitive): >>> net(x) """ __mindspore_signature__ = ( - ('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) + sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), + sig.make_sig('value', dtype=sig.sig_dtype.T) ) @prim_attr_register diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index 963c9762ba..2c6af7d0b5 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -14,17 +14,13 @@ # ============================================================================ """primitive""" - import inspect import copy from mindspore.common.api import _wrap_func from mindspore.common._register_for_tensor import tensor_operator_registry from mindspore import context from .._c_expression import Primitive_, real_run_op, prim_type -from .._c_expression import signature_rw as sig_rw -from .._c_expression import signature_kind as sig_kind -from .._c_expression import signature_dtype as sig_dtype - +from . 
import signature as sig class Primitive(Primitive_): """ @@ -54,24 +50,21 @@ class Primitive(Primitive_): self._update_parameter = False Primitive_.__init__(self, name, self) if hasattr(self.__class__, '__mindspore_signature__'): - sig = self._fill_signature(self.__class__.__mindspore_signature__) - self.set_signatures(sig) + out = self._fill_signature(self.__class__.__mindspore_signature__) + self.set_signatures(out) def _fill_signature(self, signatures): """fills signature.""" signatures_new = [] for signature in signatures: - if isinstance(signature, sig_dtype): - signatures_new.append(("argument", sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, - sig_kind.KIND_EMPTY_DEFAULT_VALUE, signature)) + if isinstance(signature, sig.Signature): + signatures_new.append(signature) + elif isinstance(signature, sig.sig_dtype): + signatures_new.append(sig.make_sig(dtype=signature)) else: if len(signature) < 3: raise ValueError(f"[Internal Error]Signature for one parameter len must > 3, but {signature}") - if len(signature) == 3: - signature += (sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T_EMPTY_DEFAULT_VALUE) - if len(signature) == 4: - signature += (sig_dtype.T_EMPTY_DEFAULT_VALUE,) - signatures_new.append(signature) + signatures_new.append(sig.make_sig(*signature)) return tuple(signatures_new) def _clone(self): diff --git a/mindspore/ops/signature.py b/mindspore/ops/signature.py new file mode 100644 index 0000000000..60debed881 --- /dev/null +++ b/mindspore/ops/signature.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""signature""" + +from .._c_expression import signature_rw as sig_rw +from .._c_expression import signature_kind as sig_kind +from .._c_expression import signature_dtype as sig_dtype +from .._c_expression import Signature + + +def make_sig(name="var", rw=sig_rw.RW_READ, + kind=sig_kind.KIND_POSITIONAL_KEYWORD, + default=sig_kind.KIND_EMPTY_DEFAULT_VALUE, + dtype=sig_dtype.T_EMPTY_DEFAULT_VALUE): + """ + Make signature for one argument. + + See `ApplyMomentum` in `mindspore.ops.operation.nn_ops` as an example. + + Args: + name (str): Argument name. Default: "var". + rw (:class:`mindspore.ops.signature.sig_rw`): Tag the argument attribute for write and read. Choose in + `[sig_rw.RW_READ, sig_rw.RW_WRITE, sig_rw.RW_REF]`, tag if the argument will update the input. + `sig_rw.RW_READ` for read only argument and `sig_rw.RW_WRITE` for write only argument. `sig_rw.RW_REF` + for the argument both need read and write. Default: sig_rw.RW_READ. + kind (:class:`mindspore.ops.signature.kind`): Choose in `[signature_kind.KIND_POSITIONAL_KEYWORD, + signature_kind.KIND_VAR_POSITIONAL, signature_kind.KIND_KEYWORD_ONLY, signature_kind.KIND_VAR_KEYWORD]`. + The meaning is the same as python argument kind, please refer to the python document. + Default: sig_kind.KIND_POSITIONAL_KEYWORD. + default (Any): The default value of argument or `sig_kind.KIND_EMPTY_DEFAULT_VALUE` for no default value. + Default: sig_kind.KIND_EMPTY_DEFAULT_VALUE.
+ dtype (:class:`mindspore.ops.signature.sig_dtype`): Choose in `signature_dtype.T` or + `signature_dtype.T1` to `signature_dtype.T9` or `sig_dtype.T_EMPTY_DEFAULT_VALUE` for no constraints. + If the signature of one argument is the same as another argument, we will perform automatic type + conversion between them. If any argument is tagged `sig_rw.RW_WRITE`, we will try to convert the other + arguments to the dtype of the `sig_rw.RW_WRITE` argument. Default: sig_dtype.T_EMPTY_DEFAULT_VALUE. + + Returns: + :class:`mindspore.ops.signature.Signature`, signature for one argument. + """ + return Signature(name, rw, kind, default, dtype) diff --git a/tests/ut/python/model/test_mix_precision.py b/tests/ut/python/model/test_mix_precision.py index 9faf7341b6..cfed2beb27 100644 --- a/tests/ut/python/model/test_mix_precision.py +++ b/tests/ut/python/model/test_mix_precision.py @@ -136,13 +136,15 @@ class NetForCast(nn.Cell): super(NetForCast, self).__init__() self.concat = P.Concat() self.x1 = Tensor(1.0, mstype.float32) + self.x2 = Parameter(Tensor(np.zeros([1, 10]).astype(np.float32)), name='x2') def construct(self, x0): - x = self.x1 * x0 + x = self.x1 * x0 * self.x2 return x def test_cast(): + context.set_context(save_graphs=True) x = Tensor(np.ones([1, 16, 10, 10]).astype(np.float32) * 0.01) net = NetForCast() net.add_flags_recursive(fp16=True) diff --git a/tests/ut/python/ops/test_array_ops.py b/tests/ut/python/ops/test_array_ops.py index ae71579973..5dbdfe42e8 100644 --- a/tests/ut/python/ops/test_array_ops.py +++ b/tests/ut/python/ops/test_array_ops.py @@ -16,9 +16,7 @@ import functools import numpy as np import pytest -from mindspore._c_expression import signature_dtype as sig_dtype -from mindspore._c_expression import signature_kind as sig_kind -from mindspore._c_expression import signature_rw as sig_rw +from mindspore.ops.signature import sig_rw, sig_dtype, make_sig import mindspore as ms from mindspore import Tensor @@ -126,9 +124,9 @@ class CustomOP(PrimitiveWithInfer): class 
CustomOP2(PrimitiveWithInfer): __mindspore_signature__ = ( - ('p1', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('p2', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('p3', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + make_sig('p1', sig_rw.RW_WRITE, dtype=sig_dtype.T), + make_sig('p2', dtype=sig_dtype.T), + make_sig('p3', dtype=sig_dtype.T), ) @prim_attr_register