|
|
|
@ -13,30 +13,30 @@
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
#include "ir/param_value.h"
|
|
|
|
|
#include "ir/param_info.h"
|
|
|
|
|
#include "pybind11/pybind11.h"
|
|
|
|
|
#include "pybind_api/api_register.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace py = pybind11;
|
|
|
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
|
|
|
|
|
(void)py::class_<ParamValue, ParamValuePtr>(*m, "ParamValue")
|
|
|
|
|
REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
|
|
|
|
|
(void)py::class_<ParamInfo, ParamValuePtr>(*m, "ParamInfo")
|
|
|
|
|
.def(py::init())
|
|
|
|
|
.def("clone", &ParamValue::Clone)
|
|
|
|
|
.def_property("name", &ParamValue::name, &ParamValue::set_name)
|
|
|
|
|
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
|
|
|
|
|
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
|
|
|
|
|
&ParamValue::set_layerwise_parallel)
|
|
|
|
|
.def("clone", &ParamInfo::Clone)
|
|
|
|
|
.def_property("name", &ParamInfo::name, &ParamInfo::set_name)
|
|
|
|
|
.def_property("requires_grad", &ParamInfo::requires_grad, &ParamInfo::set_requires_grad)
|
|
|
|
|
.def_property("layerwise_parallel", &ParamInfo::layerwise_parallel,
|
|
|
|
|
&ParamInfo::set_layerwise_parallel)
|
|
|
|
|
.def(py::pickle(
|
|
|
|
|
[](const ParamValue &p) { // __getstate__
|
|
|
|
|
[](const ParamInfo &p) { // __getstate__
|
|
|
|
|
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
|
|
|
|
|
},
|
|
|
|
|
[](const py::tuple &t) { // __setstate__
|
|
|
|
|
if (t.size() != 6) {
|
|
|
|
|
std::runtime_error("Invalid state for ParamValue!");
|
|
|
|
|
std::runtime_error("Invalid state for ParamInfo!");
|
|
|
|
|
}
|
|
|
|
|
ParamValuePtr p = std::make_shared<ParamValue>();
|
|
|
|
|
ParamValuePtr p = std::make_shared<ParamInfo>();
|
|
|
|
|
p->set_name(t[1].cast<std::string>());
|
|
|
|
|
p->set_requires_grad(t[2].cast<bool>());
|
|
|
|
|
p->set_layerwise_parallel(t[3].cast<bool>());
|