|
|
|
@ -20,11 +20,13 @@
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
|
|
#include "utils/any.h"
|
|
|
|
|
#include "utils/utils.h"
|
|
|
|
|
#include "utils/context/ms_context.h"
|
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
#include "operator/composite/do_signature.h"
|
|
|
|
|
#include "pipeline/parse/data_converter.h"
|
|
|
|
|
#include "pipeline/static_analysis/prim.h"
|
|
|
|
|
#include "session/session_factory.h"
|
|
|
|
@ -50,6 +52,57 @@ inline ValuePtr PyAttrValue(const py::object& obj) {
|
|
|
|
|
return converted_ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) {
|
|
|
|
|
auto signature = prim->signatures();
|
|
|
|
|
std::vector<SignatureEnumDType> dtypes;
|
|
|
|
|
(void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes),
|
|
|
|
|
[](const Signature& sig) { return sig.dtype; });
|
|
|
|
|
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
|
|
|
|
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
|
|
|
|
return py_args;
|
|
|
|
|
}
|
|
|
|
|
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
|
|
|
|
for (size_t i = 0; i < dtypes.size(); ++i) {
|
|
|
|
|
auto it = type_indexs.find(dtypes[i]);
|
|
|
|
|
if (it == type_indexs.end()) {
|
|
|
|
|
(void)type_indexs.insert(std::make_pair(dtypes[i], std::vector<size_t>{i}));
|
|
|
|
|
} else {
|
|
|
|
|
it->second.push_back(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
std::map<SignatureEnumDType, size_t> dst_type;
|
|
|
|
|
for (auto it = type_indexs.begin(); it != type_indexs.end(); (void)++it) {
|
|
|
|
|
auto type = it->first;
|
|
|
|
|
auto indexs = it->second;
|
|
|
|
|
if (indexs.size() < 2) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
size_t m_index = indexs[0];
|
|
|
|
|
for (size_t i = 1; i < indexs.size(); ++i) {
|
|
|
|
|
if (py::isinstance<tensor::Tensor>(py_args[indexs[i]])) {
|
|
|
|
|
m_index = indexs[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
(void)dst_type.insert(std::make_pair(type, m_index));
|
|
|
|
|
}
|
|
|
|
|
py::tuple py_inputs(py_args.size());
|
|
|
|
|
for (size_t i = 0; i < py_args.size(); ++i) {
|
|
|
|
|
auto it = dst_type.find(dtypes[i]);
|
|
|
|
|
if (it != dst_type.end() && it->second != i &&
|
|
|
|
|
(py::isinstance<py::int_>(py_args[i]) || py::isinstance<py::float_>(py_args[i]))) {
|
|
|
|
|
auto tensor_ptr = py::cast<tensor::TensorPtr>(py_args[it->second]);
|
|
|
|
|
if (py::isinstance<py::int_>(py_args[i])) {
|
|
|
|
|
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::int_>(py_args[i]), tensor_ptr->Dtype());
|
|
|
|
|
} else {
|
|
|
|
|
py_inputs[i] = std::make_shared<tensor::Tensor>(py::cast<py::float_>(py_args[i]), tensor_ptr->Dtype());
|
|
|
|
|
}
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
py_inputs[i] = py_args[i];
|
|
|
|
|
}
|
|
|
|
|
return py_inputs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) {
|
|
|
|
|
size_t size = py_args.size();
|
|
|
|
|
AbstractBasePtrList args_spec_list;
|
|
|
|
@ -73,30 +126,22 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) {
|
|
|
|
|
auto op_exec_info = std::make_shared<OpExecInfo>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_exec_info);
|
|
|
|
|
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
|
|
|
|
|
if (py::isinstance<py::none>(args[PY_PRIM])) {
|
|
|
|
|
py::module ops_mod = py::module::import("mindspore.ops.operations");
|
|
|
|
|
py::object py_primitive = ops_mod.attr(op_exec_info->op_name.c_str())();
|
|
|
|
|
op_exec_info->py_primitive = py::cast<PrimitivePyPtr>(py_primitive);
|
|
|
|
|
py::dict none_attrs = py::dict();
|
|
|
|
|
op_exec_info->op_attrs = none_attrs;
|
|
|
|
|
} else {
|
|
|
|
|
PrimitivePyPtr prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
|
|
|
|
auto pyobj = prim->GetPyObj();
|
|
|
|
|
if (pyobj == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "pyobj is empty";
|
|
|
|
|
}
|
|
|
|
|
py::tuple py_args = args[PY_INPUTS];
|
|
|
|
|
// use python infer method
|
|
|
|
|
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
|
|
|
|
PynativeInfer(prim, py_args, op_exec_info.get());
|
|
|
|
|
}
|
|
|
|
|
op_exec_info->py_primitive = prim;
|
|
|
|
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
|
|
|
|
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
|
|
|
|
auto pyobj = prim->GetPyObj();
|
|
|
|
|
if (pyobj == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "pyobj is empty";
|
|
|
|
|
}
|
|
|
|
|
py::tuple py_args = ConvertInputs(prim, args[PY_INPUTS]);
|
|
|
|
|
// use python infer method
|
|
|
|
|
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
|
|
|
|
PynativeInfer(prim, py_args, op_exec_info.get());
|
|
|
|
|
}
|
|
|
|
|
op_exec_info->op_inputs = args[PY_INPUTS];
|
|
|
|
|
op_exec_info->py_primitive = prim;
|
|
|
|
|
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
|
|
|
|
op_exec_info->op_inputs = py_args;
|
|
|
|
|
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
|
|
|
|
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "" << op_exec_info->op_name << " op_inputs size not equal op_mask";
|
|
|
|
|
MS_LOG(ERROR) << "op:" << op_exec_info->op_name << " inputs size not equal op_mask";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return op_exec_info;
|
|
|
|
|