!5645 [bug][api] update signature

Merge pull request !5645 from vlne-v1/ref_demo
pull/5645/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 021ba724cf

@ -20,7 +20,6 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
def scalar_add(x, y):
    """Implement `scalar_add`: return the sum of the two scalar operands."""
    total = x + y
    return total
@ -117,25 +116,6 @@ def bool_or(x, y):
return x or y
def vm_compare(*args):
    """Implement `vm_compare` for tensor.

    The last positional argument names the numpy attribute/method to use;
    the preceding arguments are the operands (Tensor or scalar).
    """
    method_name = args[-1]
    if method_name == "shape":
        # Shape query: hand back the numpy shape attribute directly.
        return getattr(args[0].asnumpy(), method_name)
    if len(args) == 2:
        # Unary method: invoke it on the first operand's numpy view.
        bound_method = getattr(args[0].asnumpy(), method_name)
        return Tensor(bound_method())
    if isinstance(args[0], Tensor):
        bound_method = getattr(args[0].asnumpy(), method_name)
        operand = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
    else:
        # Left operand is a plain scalar: dispatch via the reflected dunder
        # (e.g. "__lt__" -> "__rlt__") on the right operand's numpy view.
        bound_method = getattr(args[1].asnumpy(), "__r" + method_name[2:])
        operand = args[0]
    return Tensor(np.array(bound_method(operand)))
def make_list(*xs):
    """Implement `make_list`: pack the positional arguments into a new list."""
    return [*xs]

@ -262,6 +262,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
std::set<size_t> write_indices;
std::vector<TypePtr> input_types;
op_inputs.push_back(NewValueNode(function));
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
// Assume the write input of the op is always the first input. We check whether any input is a write op,
// and add a cast op on the other inputs to keep the same type as the assigned parameter.
for (size_t i = 0; i < args_spec_list.size(); ++i) {
@ -280,7 +281,6 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
TypePtr type = args_spec_list[i]->BuildType();
if (type && type->isa<RefType>()) {
auto cast_type = parse::GetMixedPrecisionTargetType(func_graph);
if (sig == SignatureEnumRW::kRWRead) {
auto source_tensor_type = type->cast<TensorTypePtr>();
if (source_tensor_type != nullptr) {
@ -300,8 +300,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter, but "
<< type->ToString();
}
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " type "
<< args_spec_list[i]->ToString();
MS_LOG(DEBUG) << "Function " << func_name << "'s input " << i << " " << param->DebugString(2) << " abs "
<< args_spec_list[i]->ToString() << " type " << type->ToString();
input_types.push_back(type);
op_inputs.push_back(param);
}

@ -305,9 +305,6 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
dic[ATTR_SHAPE] = shape;
dic[ATTR_DTYPE] = arg_slice->BuildType();
dic[ATTR_VALUE] = BuildValue(arg_slice->BuildValue());
} else if (abs_base->isa<AbstractRef>()) {
auto value = abs_base->cast<AbstractRefPtr>()->ref();
dic = ConvertAbstractToPython(value);
} else if (abs_base->isa<AbstractEllipsis>()) {
dic[ATTR_SHAPE] = py::none();
dic[ATTR_DTYPE] = py::ellipsis();

@ -23,7 +23,7 @@ namespace mindspore {
REGISTER_PYBIND_DEFINE(FuncGraph, ([](const pybind11::module *m) {
// Define python "MetaFuncGraph_" class
(void)py::class_<MetaFuncGraph, std::shared_ptr<MetaFuncGraph>>(*m, "MetaFuncGraph_")
.def(py::init<std::string &>());
.def("set_signatures", &MetaFuncGraph::set_signatures, "Set primitive inputs signature.");
// Define python "FuncGraph" class
(void)py::class_<FuncGraph, FuncGraphPtr>(*m, "FuncGraph")
.def(py::init())

@ -48,22 +48,9 @@ void SyncData(const py::object &arg) {
}
} // namespace
std::map<std::string, py::object> PrimitivePy::hook_grad_;
// Convert a Python default-argument object into a ValuePtr.
// The kKindEmptyDefaultValue sentinel means "no default was provided" and is
// mapped to a null ValuePtr rather than being converted.
static ValuePtr PyArgToValue(const py::object &arg) {
  const bool is_empty_default =
    py::isinstance<SignatureEnumKind>(arg) &&
    py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue;
  if (is_empty_default) {
    return nullptr;
  }
  // Any other Python object goes through the parser's generic data converter.
  return parse::data_converter::PyDataToValue(arg);
}
void PrimitivePy::set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>> signatures) {
signatures_.clear();
for (auto &signature : signatures) {
auto [name, rw, kind, arg_default, dtype] = signature;
auto default_value = PyArgToValue(arg_default);
signatures_.emplace_back(name, rw, kind, default_value, dtype);
}
// Replace this primitive's input signature list and flag the primitive as
// carrying an explicit signature. The signatures arrive fully constructed;
// presumably Python-side default values were already converted when the
// Signature objects were built — confirm against the pybind Signature binding.
void PrimitivePy::set_signatures(const std::vector<Signature> &signatures) {
signatures_ = signatures;
set_has_signature(true);
}

@ -42,9 +42,7 @@ class PrimitivePy : public Primitive {
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction();
void set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
signatures);
void set_signatures(const std::vector<Signature> &signatures);
const std::vector<Signature> &signatures() const { return signatures_; }

@ -17,12 +17,26 @@
#include "ir/signature.h"
#include "pybind11/operators.h"
#include "pybind_api/api_register.h"
#include "pipeline/jit/parse/data_converter.h"
namespace py = pybind11;
namespace mindspore {
// Convert a Python default-argument object to a ValuePtr.
// The kKindEmptyDefaultValue sentinel marks "no default provided"; it becomes
// a null ValuePtr instead of being run through the data converter.
static ValuePtr PyArgToValue(const py::object &arg) {
if (py::isinstance<SignatureEnumKind>(arg) &&
py::cast<SignatureEnumKind>(arg) == SignatureEnumKind::kKindEmptyDefaultValue) {
return nullptr;
}
// Every other Python object is converted via the parser's data converter.
return parse::data_converter::PyDataToValue(arg);
}
// Bind SignatureEnumRW as a python class.
REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) {
(void)py::class_<Signature>(*m, "Signature")
.def(py::init([](std::string name, SignatureEnumRW rw, SignatureEnumKind kind,
py::object arg_default, SignatureEnumDType dtype) {
auto default_value = PyArgToValue(arg_default);
return Signature(name, rw, kind, default_value, dtype);
}));
(void)py::enum_<SignatureEnumRW>(*m, "signature_rw", py::arithmetic())
.value("RW_READ", SignatureEnumRW::kRWRead)
.value("RW_WRITE", SignatureEnumRW::kRWWrite)

@ -393,3 +393,24 @@ class SparseTensor:
@property
def dense_shape(self):
return self.__dense_shape
def _vm_compare(*args):
"""Implement `vm_compare` for tensor."""
obj_str = args[-1]
if obj_str == "shape":
fn = getattr(args[0].asnumpy(), obj_str)
return fn
if len(args) == 2:
fn = getattr(args[0].asnumpy(), obj_str)
return Tensor(fn())
if isinstance(args[0], Tensor):
fn = getattr(args[0].asnumpy(), obj_str)
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
else:
obj_str = "__r" + obj_str[2:]
fn = getattr(args[1].asnumpy(), obj_str)
y = args[0]
return Tensor(np.array(fn(y)))
tensor_operator_registry.register('vm_compare', _vm_compare)

@ -34,14 +34,17 @@ from .primitive import Primitive, PrimitiveWithInfer, prim_attr_register
from .vm_impl_registry import get_vm_impl_fn, vm_impl_registry
from .op_info_register import op_info_register, AkgGpuRegOp, AkgAscendRegOp, AiCPURegOp, TBERegOp, DataType
from .primitive import constexpr
from .._c_expression import signature_rw, signature_kind
from . import composite, operations, functional
from . import signature
__primitive__ = [
"prim_attr_register", "Primitive", "PrimitiveWithInfer",
"signature_rw", "signature_kind"
"prim_attr_register", "Primitive", "PrimitiveWithInfer", "signature"
]
__all__ = ["get_vm_impl_fn", "vm_impl_registry",
"op_info_register", "AkgGpuRegOp", "AkgAscendRegOp", "AiCPURegOp", "TBERegOp", "DataType",
"constexpr"]
__all__.extend(__primitive__)
__all__.extend(composite.__all__)
__all__.extend(operations.__all__)
__all__.extend(functional.__all__)

@ -25,9 +25,8 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, Mult
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_exec, _wrap_func
from .. import functional as F
from ...common.parameter import Parameter
from ...common.tensor import Tensor
from .. import signature as sig
__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@ -348,6 +347,8 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
Args:
name (str): Operator name.
read_value (bool): If the registered function does not need to set values on a Parameter,
and all inputs will be passed by value, set `read_value` to True. Default: False.
Raises:
ValueError: Cannot find matching fn for the given args.
@ -358,16 +359,15 @@ class MultitypeFuncGraph(MultitypeFuncGraph_):
>>> add = MultitypeFuncGraph('add')
"""
def __init__(self, name):
def __init__(self, name, read_value=False):
MultitypeFuncGraph_.__init__(self, name)
self.entries = list()
if read_value:
self.set_signatures((
sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))
def __call__(self, *args):
def unwrap(arg):
if isinstance(arg, Parameter):
return arg.data
return arg
types = tuple(map(lambda arg: mstype.get_py_obj_dtype(unwrap(arg)), args))
types = tuple(map(mstype.get_py_obj_dtype, args))
for sigs, fn in self.entries:
if len(sigs) != len(types):
continue

@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
add = base.MultitypeFuncGraph('add')
add = base.MultitypeFuncGraph('add', True)
"""`add` is a metafuncgraph object which will add two objects according to input type using ".register" decorator."""

@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
div = base.MultitypeFuncGraph("div")
div = base.MultitypeFuncGraph("div", True)
"""
div is a metafuncgraph object which will div two objects according to input type
using ".register" decorator

@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
equal = base.MultitypeFuncGraph("equal")
equal = base.MultitypeFuncGraph("equal", True)
"""
equal is a metafuncgraph object which will determine if two objects are equal according to input type
using ".register" decorator

@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
floordiv = base.MultitypeFuncGraph("floordiv")
floordiv = base.MultitypeFuncGraph("floordiv", True)
"""
`floordiv` is a metafuncgraph object which will compute the floordiv of two objects
using ".register" decorator.

@ -19,7 +19,7 @@ from .. import base
from ... import functional as F
getitem = base.MultitypeFuncGraph('getitem')
getitem = base.MultitypeFuncGraph('getitem', True)
"""
getitem is a metafuncgraph object which will get item from an object according to input type
using ".register" decorator.

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# greater_equal is a metagraph object which will determine if two objects are greater_equal according to input type
# using ".register" decorator
greater_equal = base.MultitypeFuncGraph("greater_equal")
greater_equal = base.MultitypeFuncGraph("greater_equal", True)
@greater_equal.register("Number", "Number")

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# greater is a metafuncgraph object which will determine if two objects are greater according to input type
# using ".register" decorator
greater = base.MultitypeFuncGraph("greater")
greater = base.MultitypeFuncGraph("greater", True)
@greater.register("Number", "Number")

@ -19,7 +19,7 @@ from . import _constexpr_utils as const_utils
from ... import functional as F
from ...composite import base
in_ = base.MultitypeFuncGraph("in")
in_ = base.MultitypeFuncGraph("in", True)
"""
in_ is a metafuncgraph object which will determine if a in b
using ".register" decorator

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# less_equal is a metagraph object which will determine if two objects are less_equal according to input type
# using ".register" decorator
less_equal = base.MultitypeFuncGraph("less_equal")
less_equal = base.MultitypeFuncGraph("less_equal", True)
@less_equal.register("Number", "Number")

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# less is a metafuncgraph object which will determine if two objects are less according to input type
# using ".register" decorator
less = base.MultitypeFuncGraph("less")
less = base.MultitypeFuncGraph("less", True)
@less.register("Number", "Number")

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# logical_not is a metagraph object which will generate function according to input type
# using ".register" decorator
logical_not = base.MultitypeFuncGraph("logical_not")
logical_not = base.MultitypeFuncGraph("logical_not", True)
@logical_not.register("Number")

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# logical_and is a metagraph object which will generate function according to input type
# using ".register" decorator
logical_and = base.MultitypeFuncGraph("logical_and")
logical_and = base.MultitypeFuncGraph("logical_and", True)
@logical_and.register("Number", "Number")

@ -19,7 +19,7 @@ from mindspore.ops import functional as F
# logical_or is a metagraph object which will generate function according to input type
# using ".register" decorator
logical_or = base.MultitypeFuncGraph("logical_or")
logical_or = base.MultitypeFuncGraph("logical_or", True)
@logical_or.register("Number", "Number")

@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
mod = base.MultitypeFuncGraph("mod")
mod = base.MultitypeFuncGraph("mod", True)
"""
`mod` is a metafuncgraph object which will compute the mod of two objects
using ".register" decorator.

@ -19,7 +19,7 @@ from ...composite import base
from ... import functional as F
mul = base.MultitypeFuncGraph("mul")
mul = base.MultitypeFuncGraph("mul", True)
"""
`mul` is a metafuncgraph object which will multiply two objects according to input type
using ".register" decorator.

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save