From 4bede54fc1e4203fb4b8de59c4d4bfd879487a13 Mon Sep 17 00:00:00 2001 From: kpy Date: Tue, 2 Jun 2020 17:31:11 +0800 Subject: [PATCH] fix pynative refactor bug --- mindspore/ccsrc/pynative/pynative_execute.cc | 66 +++++++++++++------- mindspore/ccsrc/pynative/pynative_execute.h | 2 +- 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc index 50160f48ce..4f2a961394 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pynative/pynative_execute.cc @@ -76,6 +76,9 @@ std::string GetId(const py::object &obj) { std::string prefix = ""; if (py::isinstance(to_process)) { auto p_list = py::cast(to_process); + if (p_list.size() == 0) { + return "empty"; + } to_process = p_list[0]; prefix = "tuple:"; if (!py::isinstance(to_process)) { @@ -101,14 +104,24 @@ std::string GetId(const py::object &obj) { return py::cast(ret); } -py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args) { +py::object GetTupleObj(const py::object &obj) { + py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); + py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj); + return obj_tuple; +} + +void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) { + auto &py_args = *out_args; + for (size_t i = 0; i < args.size(); ++i) { + py_args[i] = GetTupleObj(args[i]); + } auto signature = prim->signatures(); std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.size() == 0 || static_cast(dtypes.size()) == empty_dtype_count) { - return py_args; + return; } std::map> type_indexs; for (size_t i = 0; i < dtypes.size(); ++i) { @@ -134,22 +147,19 @@ py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args) { } (void)dst_type.insert(std::make_pair(type, m_index)); } - py::list py_inputs(py_args.size()); for (size_t i = 0; i < py_args.size(); ++i) { auto it = dst_type.find(dtypes[i]); if (it != dst_type.end() && it->second != i && (py::isinstance(py_args[i]) || py::isinstance(py_args[i]))) { auto tensor_ptr = py::cast(py_args[it->second]); if (py::isinstance(py_args[i])) { - py_inputs[i] = std::make_shared(py::cast(py_args[i]), tensor_ptr->Dtype()); + py_args[i] = std::make_shared(py::cast(py_args[i]), tensor_ptr->Dtype()); } else { - py_inputs[i] = std::make_shared(py::cast(py_args[i]), tensor_ptr->Dtype()); + py_args[i] = std::make_shared(py::cast(py_args[i]), tensor_ptr->Dtype()); } continue; } - py_inputs[i] = py_args[i]; } - return py_inputs; } void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) { @@ -167,12 +177,6 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn op_exec_info->abstract = infer_res; } -py::object GetTupleObj(const py::object &obj) { - py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); - py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj); - return obj_tuple; -} - OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { if (args.size() != PY_ARGS_NUM) { MS_LOG(ERROR) << "Four args are needed by RunOp"; @@ -186,19 +190,18 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { if (pyobj == nullptr) { MS_LOG(EXCEPTION) << "pyobj is empty"; } - py::list py_args = ConvertInputs(prim, args[PY_INPUTS]); + + py::list a = args[PY_INPUTS]; + size_t input_num = a.size(); + op_exec_info->op_inputs = py::tuple(input_num); + + ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs); // use python infer method if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { - PynativeInfer(prim, py_args, op_exec_info.get()); + PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get()); } op_exec_info->py_primitive = prim; op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); - size_t input_num = py_args.size(); - op_exec_info->op_inputs = py::tuple(input_num); - for (size_t i = 0; i < input_num; ++i) { - auto obj = py_args[i]; - op_exec_info->op_inputs[i] = GetTupleObj(obj); - } op_exec_info->inputs_mask = args[PY_INPUT_MASK]; if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) { MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask"; @@ -663,8 +666,25 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c cell_graph_map_[cell_id] = curr_g_; auto out_id = GetId(out); if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) { - MS_LOG(ERROR) << "graph has no this out: " << out_id; - return; + // cell construct return x, y + if (py::isinstance(out)) { + std::vector args; + args.push_back(NewValueNode(prim::kPrimMakeTuple)); + + auto tuple = out.cast(); + MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size(); + auto tuple_size = static_cast(tuple.size()); + auto cnode = curr_g_->NewCNode(args); + for (int i = 0; i < tuple_size; i++) { + args.push_back(GetInput(tuple[i], py::object())); + set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i); + } + cnode->set_inputs(args); + set_obj_node_map(curr_g_, out_id, cnode); + } else { + MS_LOG(ERROR) << "Graph has no this out: " << out_id; + return; + } } auto output_node = GetObjNode(out); diff --git a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h index 5ed6c1ab80..d9a236c738 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pynative/pynative_execute.h @@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat py::tuple RunOp(const py::args &args); -py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args); +void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args); void ClearPyNativeSession();