From db7837eb5de40961480d9491041cebe0fbd651ba Mon Sep 17 00:00:00 2001 From: chujinjin Date: Thu, 12 Nov 2020 20:13:57 +0800 Subject: [PATCH] fix error when cell construct return list in pynative mode --- mindspore/ccsrc/pipeline/pynative/pynative_execute.cc | 6 +++--- mindspore/ccsrc/pipeline/pynative/pynative_execute.h | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index fe0ed330c1..ba94601cdc 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -1415,7 +1415,7 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar void PynativeExecutor::set_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, bool is_param) { - if (!py::isinstance(node)) { + if (!py::isinstance(node) && !py::isinstance(node)) { return; } auto tuple = node.cast(); @@ -1432,7 +1432,7 @@ void PynativeExecutor::set_node_map(const FuncGraphPtr &g, const py::object &nod void PynativeExecutor::set_tuple_node_map(const FuncGraphPtr &g, const py::object &node, const AnfNodePtr &cnode, const std::vector &idx, bool is_param) { - if (!py::isinstance(node)) { + if (!py::isinstance(node) && !py::isinstance(node)) { return; } auto tuple = node.cast(); @@ -1461,7 +1461,7 @@ void PynativeExecutor::EndGraphInner(const py::object &cell, const py::object &o auto out_id = GetId(out); // x =op1, y =op2, return (x, y) if (graph_info_map_[curr_g_].node_map.find(out_id) == graph_info_map_[curr_g_].node_map.end()) { - if (py::isinstance(out)) { + if (py::isinstance(out) || py::isinstance(out)) { auto tuple = out.cast(); auto tuple_size = static_cast(tuple.size()); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 0b8bc1be88..0738a95104 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -117,7 +117,6 @@ class PynativeExecutor : public std::enable_shared_from_this { // replace for grad graph ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple); ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info); - void SaveForwardResult(const CNodePtr &cnode, const py::object &out); void GenTupleMap(const ValueTuplePtr &tuple, std::map *t_map); void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);