diff --git a/mindspore/ccsrc/ir/dtype_extends.cc b/mindspore/ccsrc/ir/dtype_extends.cc index 732872cb4f..b41631c1ce 100644 --- a/mindspore/ccsrc/ir/dtype_extends.cc +++ b/mindspore/ccsrc/ir/dtype_extends.cc @@ -20,8 +20,6 @@ #include #include "utils/log_adapter.h" #include "pipeline/static_analysis/abstract_value.h" -#include "pybind_api/api_register.h" -#include "pybind_api/export_flags.h" namespace mindspore { TypePtr TypeAnything::DeepCopy() const { return kAnyType; } @@ -425,134 +423,6 @@ bool IsSubType(TypePtr const &t1, TypePtr const &t2) { } } -REGISTER_PYBIND_DEFINE( - typing, ([](py::module *const m) { - auto m_sub = m->def_submodule("typing", "submodule for dtype"); - py::enum_(m_sub, "TypeId"); - (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); - (void)m_sub.def("load_type", &TypeIdToType, "load type"); - (void)m_sub.def( - "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); - (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); - (void)py::class_>(m_sub, "Type") - .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) - .def("__eq__", - [](const TypePtr &t1, const TypePtr &t2) { - if (t1 != nullptr && t2 != nullptr) { - return *t1 == *t2; - } - return false; - }) - .def("__hash__", &Type::hash) - .def("__str__", &Type::ToString) - .def("__repr__", &Type::ReprString) - .def("__deepcopy__", [](const TypePtr &t, py::dict) { - if (t == nullptr) { - return static_cast(nullptr); - } - return t->DeepCopy(); - }); - (void)py::class_>(m_sub, "Number").def(py::init()); - (void)py::class_>(m_sub, "Bool") - .def(py::init()) - .def(py::pickle( - [](const Bool &) { // __getstate__ - return py::make_tuple(); - }, - [](const py::tuple &) { // __setstate__ - return std::make_shared(); - })); - (void)py::class_>(m_sub, "Int") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const Int &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - Int data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "UInt") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const UInt &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - UInt data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "Float") - .def(py::init()) - .def(py::init(), py::arg("nbits")) - .def(py::pickle( - [](const Float &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(t.nbits())); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - Float data(t[0].cast()); - return data; - })); - (void)py::class_>(m_sub, "List") - .def(py::init()) - .def(py::init>(), py::arg("elements")); - (void)py::class_>(m_sub, "Tuple") - .def(py::init()) - .def(py::init>(), py::arg("elements")); - (void)py::class_>(m_sub, "TensorType") - .def(py::init()) - .def(py::init(), py::arg("element")) - .def("element_type", &TensorType::element) - .def(py::pickle( - [](const TensorType &t) { // __getstate__ - /* Return a tuple that fully encodes the state of the object */ - return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); - }, - [](const py::tuple &t) { // __setstate__ - if (t.size() != 1) { - throw std::runtime_error("Invalid state!"); - } - /* Create a new C++ instance */ - TensorType data(TypeIdToType(TypeId(static_cast(t[0].cast())))); - return data; - })); - (void)py::class_>(m_sub, "IndexedSlicesType") - .def(py::init()); - (void)py::class_>(m_sub, "UndeterminedType") - .def(py::init()); - (void)py::class_>(m_sub, "Function") - .def(py::init()) - .def(py::init, TypePtr>(), py::arg("args"), py::arg("retval")); - (void)py::class_>(m_sub, "Class").def(py::init()); - (void)py::class_>(m_sub, "SymbolicKeyType").def(py::init()); - (void)py::class_>(m_sub, "EnvType").def(py::init()); - (void)py::class_>(m_sub, "TypeNone").def(py::init()); - (void)py::class_>(m_sub, "TypeType").def(py::init()); - (void)py::class_>(m_sub, "String").def(py::init()); - (void)py::class_>(m_sub, "RefKeyType").def(py::init()); - (void)py::class_>(m_sub, "RefType").def(py::init()); - (void)py::class_>(m_sub, "TypeAnything").def(py::init()); - (void)py::class_>(m_sub, "Slice").def(py::init()); - (void)py::class_>(m_sub, "TypeEllipsis").def(py::init()); - })); - const TypePtr kTypeExternal = std::make_shared(); const TypePtr kTypeEnv = std::make_shared(); const TypePtr kTypeType = std::make_shared(); diff --git a/mindspore/ccsrc/ir/dtype_py.cc b/mindspore/ccsrc/ir/dtype_py.cc new file mode 100644 index 0000000000..c8b34a48e9 --- /dev/null +++ b/mindspore/ccsrc/ir/dtype_py.cc @@ -0,0 +1,155 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/dtype.h" +#include +#include +#include +#include "utils/log_adapter.h" +#include "pipeline/static_analysis/abstract_value.h" +#include "pybind_api/api_register.h" +#include "pybind_api/export_flags.h" + +namespace mindspore { +// Define python wrapper to handle data types. +REGISTER_PYBIND_DEFINE( + typing, ([](py::module *const m) { + auto m_sub = m->def_submodule("typing", "submodule for dtype"); + py::enum_(m_sub, "TypeId"); + (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); + (void)m_sub.def("load_type", &TypeIdToType, "load type"); + (void)m_sub.def( + "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); + (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); + (void)py::class_>(m_sub, "Type") + .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) + .def("__eq__", + [](const TypePtr &t1, const TypePtr &t2) { + if (t1 != nullptr && t2 != nullptr) { + return *t1 == *t2; + } + return false; + }) + .def("__hash__", &Type::hash) + .def("__str__", &Type::ToString) + .def("__repr__", &Type::ReprString) + .def("__deepcopy__", [](const TypePtr &t, py::dict) { + if (t == nullptr) { + return static_cast(nullptr); + } + return t->DeepCopy(); + }); + (void)py::class_>(m_sub, "Number").def(py::init()); + (void)py::class_>(m_sub, "Bool") + .def(py::init()) + .def(py::pickle( + [](const Bool &) { // __getstate__ + return py::make_tuple(); + }, + [](const py::tuple &) { // __setstate__ + return std::make_shared(); + })); + (void)py::class_>(m_sub, "Int") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const Int &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + Int data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "UInt") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const UInt &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + UInt data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "Float") + .def(py::init()) + .def(py::init(), py::arg("nbits")) + .def(py::pickle( + [](const Float &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(t.nbits())); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + Float data(t[0].cast()); + return data; + })); + (void)py::class_>(m_sub, "List") + .def(py::init()) + .def(py::init>(), py::arg("elements")); + (void)py::class_>(m_sub, "Tuple") + .def(py::init()) + .def(py::init>(), py::arg("elements")); + (void)py::class_>(m_sub, "TensorType") + .def(py::init()) + .def(py::init(), py::arg("element")) + .def("element_type", &TensorType::element) + .def(py::pickle( + [](const TensorType &t) { // __getstate__ + /* Return a tuple that fully encodes the state of the object */ + return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); + }, + [](const py::tuple &t) { // __setstate__ + if (t.size() != 1) { + throw std::runtime_error("Invalid state!"); + } + /* Create a new C++ instance */ + TensorType data(TypeIdToType(TypeId(static_cast(t[0].cast())))); + return data; + })); + (void)py::class_>(m_sub, "IndexedSlicesType") + .def(py::init()); + (void)py::class_>(m_sub, "UndeterminedType") + .def(py::init()); + (void)py::class_>(m_sub, "Function") + .def(py::init()) + .def(py::init, TypePtr>(), py::arg("args"), py::arg("retval")); + (void)py::class_>(m_sub, "Class").def(py::init()); + (void)py::class_>(m_sub, "SymbolicKeyType").def(py::init()); + (void)py::class_>(m_sub, "EnvType").def(py::init()); + (void)py::class_>(m_sub, "TypeNone").def(py::init()); + (void)py::class_>(m_sub, "TypeType").def(py::init()); + (void)py::class_>(m_sub, "String").def(py::init()); + (void)py::class_>(m_sub, "RefKeyType").def(py::init()); + (void)py::class_>(m_sub, "RefType").def(py::init()); + (void)py::class_>(m_sub, "TypeAnything").def(py::init()); + (void)py::class_>(m_sub, "Slice").def(py::init()); + (void)py::class_>(m_sub, "TypeEllipsis").def(py::init()); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index d7a6eb81f7..b0d0910304 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -25,7 +25,6 @@ #include "debug/trace.h" #include "ir/manager.h" #include "operator/ops.h" -#include "pybind_api/export_flags.h" #include "utils/ordered_set.h" #include "utils/convert_utils_base.h" diff --git a/mindspore/ccsrc/ir/meta_func_graph.h b/mindspore/ccsrc/ir/meta_func_graph.h index f63f812f9e..8b43bafe7f 100644 --- a/mindspore/ccsrc/ir/meta_func_graph.h +++ b/mindspore/ccsrc/ir/meta_func_graph.h @@ -26,16 +26,12 @@ #include #include -#include "pybind11/pybind11.h" - #include "ir/dtype.h" #include "ir/anf.h" #include "ir/func_graph.h" #include "ir/signature.h" #include "pipeline/static_analysis/abstract_value.h" -namespace py = pybind11; - namespace mindspore { // namespace to support intermediate representation definition // Graph generator. diff --git a/mindspore/ccsrc/ir/value_extends.cc b/mindspore/ccsrc/ir/value_extends.cc index 8eb34d0eeb..f5f9bb8f28 100644 --- a/mindspore/ccsrc/ir/value_extends.cc +++ b/mindspore/ccsrc/ir/value_extends.cc @@ -20,7 +20,6 @@ #include #include -#include "pybind_api/api_register.h" #include "pipeline/static_analysis/abstract_value.h" namespace mindspore { @@ -83,9 +82,4 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() { [](const std::pair &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); return std::make_shared(kv); } - -REGISTER_PYBIND_DEFINE( - RefKey, ([](const py::module *m) { - (void)py::class_>(*m, "RefKey").def(py::init(), py::arg("tag")); - })); } // namespace mindspore diff --git a/mindspore/ccsrc/ir/value_py.cc b/mindspore/ccsrc/ir/value_py.cc new file mode 100644 index 0000000000..7207cd06d6 --- /dev/null +++ b/mindspore/ccsrc/ir/value_py.cc @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ir/value.h" +#include + +#include "pybind_api/api_register.h" +#include "pipeline/static_analysis/abstract_value.h" + +namespace mindspore { +// Define python 'RefKey' class. +REGISTER_PYBIND_DEFINE( + RefKey, ([](const py::module *m) { + (void)py::class_>(*m, "RefKey").def(py::init(), py::arg("tag")); + })); +} // namespace mindspore