|
|
|
@ -661,6 +661,20 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
|
|
|
|
|
// out = op(cell1(x, y))
|
|
|
|
|
// out = op(cell1(x, y)[0])
|
|
|
|
|
node = GetObjNode(obj);
|
|
|
|
|
} else if (py::isinstance<py::tuple>(obj)) {
|
|
|
|
|
// out = op((x, y))
|
|
|
|
|
// out = cell((x, y))
|
|
|
|
|
std::vector<AnfNodePtr> args;
|
|
|
|
|
args.push_back(NewValueNode(prim::kPrimMakeTuple));
|
|
|
|
|
|
|
|
|
|
auto tuple = obj.cast<py::tuple>();
|
|
|
|
|
auto tuple_size = static_cast<int>(tuple.size());
|
|
|
|
|
for (int i = 0; i < tuple_size; i++) {
|
|
|
|
|
args.push_back(GetInput(tuple[i], py::object()));
|
|
|
|
|
}
|
|
|
|
|
auto cnode = curr_g_->NewCNode(args);
|
|
|
|
|
set_obj_node_map(curr_g_, GetId(obj), cnode);
|
|
|
|
|
node = cnode;
|
|
|
|
|
} else {
|
|
|
|
|
// out = op(x, 1)
|
|
|
|
|
ValuePtr converted_ret = nullptr;
|
|
|
|
@ -728,6 +742,13 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
|
|
|
|
}
|
|
|
|
|
auto out_cnode = curr_g_->NewCNode(inputs);
|
|
|
|
|
set_pyobj(curr_g_, GetId(cell));
|
|
|
|
|
if (py::isinstance<py::tuple>(out)) {
|
|
|
|
|
auto out_list = py::cast<py::tuple>(out);
|
|
|
|
|
auto out_size = static_cast<int>(out_list.size());
|
|
|
|
|
for (int i = 0; i < out_size; i++) {
|
|
|
|
|
set_obj_node_map(curr_g_, GetId(out_list[i]), out_cnode, i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
set_obj_node_map(curr_g_, GetId(out), out_cnode);
|
|
|
|
|
} else {
|
|
|
|
|
parse::ResolveFuncGraph(newfg, resource_);
|
|
|
|
|