diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 4730404101..1ea1c059a8 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -767,9 +767,14 @@ void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr if (iter != op_forward_map_.end()) { return; } - op_forward_map_[id] = value; + auto tuple_info_iter = obj_to_forward_id_tuple_info_.find(id); + ValuePtr temp_value = value; + if (tuple_info_iter != obj_to_forward_id_tuple_info_.end()) { + temp_value = tuple_info_iter->second; + } + op_forward_map_[id] = temp_value; MS_LOG(DEBUG) << "Save op forward value: " - << "(" << id << "), " << value; + << "(" << id << "), " << temp_value; } void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { @@ -799,6 +804,14 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN cnode->set_forward(value, op_id); ++op_id_map_[id]; auto out_id = GetId(out_real); + if (py::isinstance(out_real)) { + auto tuple_item = py::cast(out_real); + for (size_t i = 0; i < tuple_item.size(); i++) { + auto tuple_item_id = GetId(tuple_item[i]); + obj_to_forward_id_[tuple_item_id] = op_id; + } + obj_to_forward_id_tuple_info_[op_id] = value; + } obj_to_forward_id_[out_id] = op_id; } } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index 708883dd2e..0e4276559a 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -155,6 +155,7 @@ class PynativeExecutor : public std::enable_shared_from_this { std::unordered_map op_forward_map_; std::unordered_map op_id_map_; std::unordered_map obj_to_forward_id_; + std::unordered_map obj_to_forward_id_tuple_info_; std::unordered_map node_abs_map_; std::unordered_map df_builder_map_; // the stack that records the context of graph created, the bottom is the top graph