!1878 fix tuple args issue in pynative mode

Merge pull request !1878 from wangqiuliang/fix-tuple-args-issue-in-pynative
pull/1878/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit ed77c761ec

@ -661,6 +661,20 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
// out = op(cell1(x, y))
// out = op(cell1(x, y)[0])
node = GetObjNode(obj);
} else if (py::isinstance<py::tuple>(obj)) {
// out = op((x, y))
// out = cell((x, y))
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));
auto tuple = obj.cast<py::tuple>();
auto tuple_size = static_cast<int>(tuple.size());
for (int i = 0; i < tuple_size; i++) {
args.push_back(GetInput(tuple[i], py::object()));
}
auto cnode = curr_g_->NewCNode(args);
set_obj_node_map(curr_g_, GetId(obj), cnode);
node = cnode;
} else {
// out = op(x, 1)
ValuePtr converted_ret = nullptr;
@ -728,6 +742,13 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
}
auto out_cnode = curr_g_->NewCNode(inputs);
set_pyobj(curr_g_, GetId(cell));
if (py::isinstance<py::tuple>(out)) {
auto out_list = py::cast<py::tuple>(out);
auto out_size = static_cast<int>(out_list.size());
for (int i = 0; i < out_size; i++) {
set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
}
}
set_obj_node_map(curr_g_, GetId(out), out_cnode);
} else {
parse::ResolveFuncGraph(newfg, resource_);

@ -19,6 +19,7 @@ from .._c_expression import Tensor as Tensor_
from .._c_expression import MetaTensor
from .._checkparam import check_type, check_typename
from . import dtype as mstype
from .. import context
from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor']
@ -77,6 +78,8 @@ class Tensor(Tensor_):
def __eq__(self, other):
if not isinstance(other, Tensor):
return False
if context.get_context("enable_ge") or self.dtype() == mstype.bool_ or other.dtype() == mstype.bool_:
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
return tensor_operator_registry.get('__eq__')(self, other)
def __ne__(self, other):

@ -145,8 +145,8 @@ class SameTypeShape(PrimitiveWithInfer):
def __call__(self, x, y):
"""run in PyNative mode"""
validator.check_subclass('x', x.dtype(), mstype.tensor, self.name)
validator.check_subclass('y', y.dtype(), mstype.tensor, self.name)
validator.check_value_type("x", x, Tensor, self.name)
validator.check_value_type("y", y, Tensor, self.name)
validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError)
validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name)
return x

Loading…
Cancel
Save