diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc index 6ec27c2567..3526e47f96 100644 --- a/mindspore/ccsrc/ir/primitive.cc +++ b/mindspore/ccsrc/ir/primitive.cc @@ -30,17 +30,21 @@ #include "pybind_api/export_flags.h" namespace mindspore { +static ValuePtr PyArgToValue(const py::object &arg) { + if (py::isinstance(arg) && + py::cast(arg) == SignatureEnumKind::kKindEmptyDefaultValue) { + return nullptr; + } + return parse::data_converter::PyDataToValue(arg); +} + void PrimitivePy::set_signatures( std::vector> signatures) { signatures_.clear(); for (auto &signature : signatures) { - std::string name; - SignatureEnumRW rw; - SignatureEnumKind kind; - py::object default_value; - SignatureEnumDType dtype; - std::tie(name, rw, kind, default_value, dtype) = signature; - signatures_.emplace_back(Signature(name, rw, kind, default_value, dtype)); + auto [name, rw, kind, arg_default, dtype] = signature; + auto default_value = PyArgToValue(arg_default); + signatures_.emplace_back(name, rw, kind, default_value, dtype); } set_has_signature(true); } diff --git a/mindspore/ccsrc/ir/signature.h b/mindspore/ccsrc/ir/signature.h index 48be7e0f31..e9a5a2e1ca 100644 --- a/mindspore/ccsrc/ir/signature.h +++ b/mindspore/ccsrc/ir/signature.h @@ -16,14 +16,11 @@ #ifndef MINDSPORE_CCSRC_IR_SIGNATURE_H_ #define MINDSPORE_CCSRC_IR_SIGNATURE_H_ + #include #include - -#include "pybind11/operators.h" #include "ir/value.h" -namespace py = pybind11; - namespace mindspore { // Input signature, support type enum SignatureEnumRW { @@ -62,8 +59,10 @@ struct Signature { ValuePtr default_value; // nullptr for no default value SignatureEnumDType dtype; Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, - const py::object &arg_default, const SignatureEnumDType &arg_dtype); - Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind); + const ValuePtr &arg_default, const SignatureEnumDType &arg_dtype) + : name(arg_name), rw(rw_tag), kind(arg_kind), default_value(arg_default), dtype(arg_dtype) {} + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) + : Signature(arg_name, rw_tag, arg_kind, nullptr, SignatureEnumDType::kDTypeEmptyDefaultValue) {} }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/signature.cc b/mindspore/ccsrc/ir/signature_py.cc similarity index 77% rename from mindspore/ccsrc/ir/signature.cc rename to mindspore/ccsrc/ir/signature_py.cc index 8f312d5b98..2b01b3e579 100644 --- a/mindspore/ccsrc/ir/signature.cc +++ b/mindspore/ccsrc/ir/signature_py.cc @@ -15,30 +15,14 @@ */ #include "ir/signature.h" - #include "pybind11/operators.h" #include "pybind_api/api_register.h" #include "pipeline/parse/data_converter.h" -namespace mindspore { -Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, - const py::object &arg_default, const SignatureEnumDType &arg_dtype) - : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { - if (py::isinstance(arg_default) && - py::cast(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { - default_value = nullptr; - } else { - default_value = parse::data_converter::PyDataToValue(arg_default); - } -} - -Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) - : name(arg_name), - rw(rw_tag), - kind(arg_kind), - default_value(nullptr), - dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} +namespace py = pybind11; +namespace mindspore { +// Bind SignatureEnumRW as a python class. REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { (void)py::enum_(*m, "signature_rw", py::arithmetic()) .value("RW_READ", SignatureEnumRW::kRWRead)