!1798 fix pynative bug

Merge pull request !1798 from flywind/fix_pynative_bug
pull/1798/MERGE
mindspore-ci-bot authored 5 years ago; committed by Gitee
commit ee6252627f

@@ -76,6 +76,9 @@ std::string GetId(const py::object &obj) {
   std::string prefix = "";
   if (py::isinstance<py::tuple>(to_process)) {
     auto p_list = py::cast<py::tuple>(to_process);
+    if (p_list.size() == 0) {
+      return "empty";
+    }
     to_process = p_list[0];
     prefix = "tuple:";
     if (!py::isinstance<tensor::Tensor>(to_process)) {
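This hunk guards GetId against an empty tuple input: previously p_list[0] was read unconditionally, which raises IndexError through pybind11 when the tuple has no elements. A minimal standalone sketch of the pattern (hypothetical FirstElementId function, assuming plain pybind11; not the MindSpore code itself):

// Hypothetical sketch of the empty-tuple guard added above.
#include <pybind11/pybind11.h>
#include <string>
namespace py = pybind11;

std::string FirstElementId(const py::object &obj) {
  if (py::isinstance<py::tuple>(obj)) {
    auto t = py::cast<py::tuple>(obj);
    if (t.size() == 0) {
      return "empty";  // bail out instead of indexing t[0]
    }
    return "tuple:" + py::cast<std::string>(py::str(t[0]));
  }
  return py::cast<std::string>(py::str(obj));
}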
@@ -101,14 +104,24 @@ std::string GetId(const py::object &obj) {
   return py::cast<std::string>(ret);
 }
-py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args) {
+py::object GetTupleObj(const py::object &obj) {
+  py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
+  py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
+  return obj_tuple;
+}
+void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
+  auto &py_args = *out_args;
+  for (size_t i = 0; i < args.size(); ++i) {
+    py_args[i] = GetTupleObj(args[i]);
+  }
   auto signature = prim->signatures();
   std::vector<SignatureEnumDType> dtypes;
   (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
                        [](const Signature &sig) { return sig.dtype; });
   int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
   if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
-    return py_args;
+    return;
   }
   std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
   for (size_t i = 0; i < dtypes.size(); ++i) {
@@ -134,22 +147,19 @@ py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args) {
     }
     (void)dst_type.insert(std::make_pair(type, m_index));
   }
-  py::list py_inputs(py_args.size());
   for (size_t i = 0; i < py_args.size(); ++i) {
     auto it = dst_type.find(dtypes[i]);
     if (it != dst_type.end() && it->second != i &&
         (py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
       auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
       if (py::isinstance<py::int_>(py_args[i])) {
-        py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
+        py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
       } else {
-        py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
+        py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
       }
       continue;
     }
-    py_inputs[i] = py_args[i];
   }
-  return py_inputs;
 }
 void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
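ConvertInputs now returns void and fills a caller-allocated py::tuple in place: each argument is first passed through GetTupleObj (moved up from its old position below) to expand default inputs, and any plain Python int or float that shares a signature dtype group with a tensor argument is then promoted to a Tensor of that tensor's dtype, which makes the separate py_inputs copy list unnecessary. A minimal sketch of the in-place promotion step in isolation (hypothetical PromoteScalarLike helper; tensor::Tensor and its pybind11 TensorPtr caster are assumed from MindSpore, and the index pairing comes from the dtype groups computed above):

// Sketch: promote a scalar argument to the dtype of a paired tensor argument.
void PromoteScalarLike(py::tuple *out_args, size_t i, size_t tensor_index) {
  auto &a = *out_args;
  auto tensor_ptr = py::cast<tensor::TensorPtr>(a[tensor_index]);
  if (py::isinstance<py::int_>(a[i])) {
    a[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(a[i]), tensor_ptr->Dtype());
  } else if (py::isinstance<py::float_>(a[i])) {
    a[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(a[i]), tensor_ptr->Dtype());
  }
}

For example, calling an op with a float16 tensor and the literal 2 would, under this rule, hand the kernel a float16 Tensor for the scalar rather than a raw Python int.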
@@ -167,12 +177,6 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
   op_exec_info->abstract = infer_res;
 }
-py::object GetTupleObj(const py::object &obj) {
-  py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
-  py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
-  return obj_tuple;
-}
 OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   if (args.size() != PY_ARGS_NUM) {
     MS_LOG(ERROR) << "Four args are needed by RunOp";
@@ -186,19 +190,18 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
   if (pyobj == nullptr) {
     MS_LOG(EXCEPTION) << "pyobj is empty";
   }
-  py::list py_args = ConvertInputs(prim, args[PY_INPUTS]);
+  py::list a = args[PY_INPUTS];
+  size_t input_num = a.size();
+  op_exec_info->op_inputs = py::tuple(input_num);
+  ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
   // use python infer method
   if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
-    PynativeInfer(prim, py_args, op_exec_info.get());
+    PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
   }
   op_exec_info->py_primitive = prim;
   op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
-  size_t input_num = py_args.size();
-  op_exec_info->op_inputs = py::tuple(input_num);
-  for (size_t i = 0; i < input_num; ++i) {
-    auto obj = py_args[i];
-    op_exec_info->op_inputs[i] = GetTupleObj(obj);
-  }
   op_exec_info->inputs_mask = args[PY_INPUT_MASK];
   if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
     MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
@@ -663,8 +666,25 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
   cell_graph_map_[cell_id] = curr_g_;
   auto out_id = GetId(out);
   if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
-    MS_LOG(ERROR) << "graph has no this out: " << out_id;
-    return;
+    // cell construct return x, y
+    if (py::isinstance<py::tuple>(out)) {
+      std::vector<AnfNodePtr> args;
+      args.push_back(NewValueNode(prim::kPrimMakeTuple));
+      auto tuple = out.cast<py::tuple>();
+      MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size();
+      auto tuple_size = static_cast<int>(tuple.size());
+      auto cnode = curr_g_->NewCNode(args);
+      for (int i = 0; i < tuple_size; i++) {
+        args.push_back(GetInput(tuple[i], py::object()));
+        set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
+      }
+      cnode->set_inputs(args);
+      set_obj_node_map(curr_g_, out_id, cnode);
+    } else {
+      MS_LOG(ERROR) << "Graph has no this out: " << out_id;
+      return;
+    }
   }
   auto output_node = GetObjNode(out);
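This is the tuple-output fix the PR title refers to: when a cell's construct returns multiple values (return x, y), the output is a Python tuple whose id was never registered in obj_node_map, so EndGraph used to fail with "graph has no this out". The hunk now synthesizes a MakeTuple CNode instead: the node is created from just the kPrimMakeTuple value node, each element is mapped to (cnode, i) so later lookups of an individual element resolve through it, and set_inputs installs the complete input list at the end. A compact sketch of the pattern (ElementNode and RegisterElement are hypothetical stand-ins for GetInput and set_obj_node_map):

// Sketch: building a MakeTuple CNode for a tuple of outputs.
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto cnode = graph->NewCNode(inputs);        // node created first...
for (size_t i = 0; i < elems.size(); ++i) {
  inputs.push_back(ElementNode(elems[i]));   // ...then element inputs collected
  RegisterElement(elems[i], cnode, i);       // element i resolves to (cnode, i)
}
cnode->set_inputs(inputs);                   // full input list installed last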

@@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
 py::tuple RunOp(const py::args &args);
-py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args);
+void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
 void ClearPyNativeSession();
