!8529 fix error when cell construct return list output in pynative mode

From: @chujinjin
Reviewed-by: @kisnwang,@jjfeing
Signed-off-by: @jjfeing
pull/8529/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 1a152a9663

@ -1415,7 +1415,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar
void PynativeExecutor::set_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode,
bool is_param) {
if (!py::isinstance<py::tuple>(node)) {
if (!py::isinstance<py::tuple>(node) && !py::isinstance<py::list>(node)) {
return;
}
auto tuple = node.cast<py::tuple>();
@ -1432,7 +1432,7 @@ void PynativeExecutor::set_node_map(const FuncGraphPtr &g, const py::object &nod
void PynativeExecutor::set_tuple_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode,
const std::vector<int64_t> &idx, bool is_param) {
if (!py::isinstance<py::tuple>(node)) {
if (!py::isinstance<py::tuple>(node) && !py::isinstance<py::list>(node)) {
return;
}
auto tuple = node.cast<py::tuple>();
@ -1461,7 +1461,7 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o
auto out_id = GetId(out);
// x =op1, y =op2, return (x, y)
if (graph_info_map_[curr_g_].node_map.find(out_id) == graph_info_map_[curr_g_].node_map.end()) {
if (py::isinstance<py::tuple>(out)) {
if (py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out)) {
auto tuple = out.cast<py::tuple>();
auto tuple_size = static_cast<int64_t>(tuple.size());

@ -117,7 +117,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
// replace for grad graph
ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple);
ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map);
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);

Loading…
Cancel
Save