|
|
|
@ -76,6 +76,9 @@ std::string GetId(const py::object &obj) {
|
|
|
|
|
std::string prefix = "";
|
|
|
|
|
if (py::isinstance<py::tuple>(to_process)) {
|
|
|
|
|
auto p_list = py::cast<py::tuple>(to_process);
|
|
|
|
|
if (p_list.size() == 0) {
|
|
|
|
|
return "empty";
|
|
|
|
|
}
|
|
|
|
|
to_process = p_list[0];
|
|
|
|
|
prefix = "tuple:";
|
|
|
|
|
if (!py::isinstance<tensor::Tensor>(to_process)) {
|
|
|
|
@ -101,14 +104,24 @@ std::string GetId(const py::object &obj) {
|
|
|
|
|
return py::cast<std::string>(ret);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args) {
|
|
|
|
|
py::object GetTupleObj(const py::object &obj) {
|
|
|
|
|
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
|
|
|
|
py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
|
|
|
|
|
return obj_tuple;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
|
|
|
|
|
auto &py_args = *out_args;
|
|
|
|
|
for (size_t i = 0; i < args.size(); ++i) {
|
|
|
|
|
py_args[i] = GetTupleObj(args[i]);
|
|
|
|
|
}
|
|
|
|
|
auto signature = prim->signatures();
|
|
|
|
|
std::vector<SignatureEnumDType> dtypes;
|
|
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
|
|
|
|
[](const Signature &sig) { return sig.dtype; });
|
|
|
|
|
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
|
|
|
|
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
|
|
|
|
return py_args;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
|
|
|
|
for (size_t i = 0; i < dtypes.size(); ++i) {
|
|
|
|
@ -134,22 +147,19 @@ py::list ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args) {
|
|
|
|
|
}
|
|
|
|
|
(void)dst_type.insert(std::make_pair(type, m_index));
|
|
|
|
|
}
|
|
|
|
|
py::list py_inputs(py_args.size());
|
|
|
|
|
for (size_t i = 0; i < py_args.size(); ++i) {
|
|
|
|
|
auto it = dst_type.find(dtypes[i]);
|
|
|
|
|
if (it != dst_type.end() && it->second != i &&
|
|
|
|
|
(py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
|
|
|
|
|
auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
|
|
|
|
|
if (py::isinstance<py::int_>(py_args[i])) {
|
|
|
|
|
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
|
|
|
|
|
py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
|
|
|
|
|
} else {
|
|
|
|
|
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
|
|
|
|
|
py_args[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
py_inputs[i] = py_args[i];
|
|
|
|
|
}
|
|
|
|
|
return py_inputs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
|
|
|
|
@ -167,12 +177,6 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
|
|
|
|
|
op_exec_info->abstract = infer_res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::object GetTupleObj(const py::object &obj) {
|
|
|
|
|
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
|
|
|
|
py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
|
|
|
|
|
return obj_tuple;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
|
|
|
|
if (args.size() != PY_ARGS_NUM) {
|
|
|
|
|
MS_LOG(ERROR) << "Four args are needed by RunOp";
|
|
|
|
@ -186,19 +190,18 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
|
|
|
|
if (pyobj == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "pyobj is empty";
|
|
|
|
|
}
|
|
|
|
|
py::list py_args = ConvertInputs(prim, args[PY_INPUTS]);
|
|
|
|
|
|
|
|
|
|
py::list a = args[PY_INPUTS];
|
|
|
|
|
size_t input_num = a.size();
|
|
|
|
|
op_exec_info->op_inputs = py::tuple(input_num);
|
|
|
|
|
|
|
|
|
|
ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
|
|
|
|
|
// use python infer method
|
|
|
|
|
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
|
|
|
|
PynativeInfer(prim, py_args, op_exec_info.get());
|
|
|
|
|
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
|
|
|
|
|
}
|
|
|
|
|
op_exec_info->py_primitive = prim;
|
|
|
|
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
|
|
|
|
size_t input_num = py_args.size();
|
|
|
|
|
op_exec_info->op_inputs = py::tuple(input_num);
|
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
|
auto obj = py_args[i];
|
|
|
|
|
op_exec_info->op_inputs[i] = GetTupleObj(obj);
|
|
|
|
|
}
|
|
|
|
|
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
|
|
|
|
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
|
|
|
|
@ -663,8 +666,25 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
|
|
|
|
cell_graph_map_[cell_id] = curr_g_;
|
|
|
|
|
auto out_id = GetId(out);
|
|
|
|
|
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
|
|
|
|
|
MS_LOG(ERROR) << "graph has no this out: " << out_id;
|
|
|
|
|
return;
|
|
|
|
|
// cell construct return x, y
|
|
|
|
|
if (py::isinstance<py::tuple>(out)) {
|
|
|
|
|
std::vector<AnfNodePtr> args;
|
|
|
|
|
args.push_back(NewValueNode(prim::kPrimMakeTuple));
|
|
|
|
|
|
|
|
|
|
auto tuple = out.cast<py::tuple>();
|
|
|
|
|
MS_LOG(DEBUG) << "End graph start tuple size" << tuple.size();
|
|
|
|
|
auto tuple_size = static_cast<int>(tuple.size());
|
|
|
|
|
auto cnode = curr_g_->NewCNode(args);
|
|
|
|
|
for (int i = 0; i < tuple_size; i++) {
|
|
|
|
|
args.push_back(GetInput(tuple[i], py::object()));
|
|
|
|
|
set_obj_node_map(curr_g_, GetId(tuple[i]), cnode, i);
|
|
|
|
|
}
|
|
|
|
|
cnode->set_inputs(args);
|
|
|
|
|
set_obj_node_map(curr_g_, out_id, cnode);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "Graph has no this out: " << out_id;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto output_node = GetObjNode(out);
|
|
|
|
|