!2798 Decouple ParamValue from python
Merge pull request !2798 from hewei/decouple_param_valuepull/2798/MERGE
commit
5f468b65b6
@ -0,0 +1,95 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_H_
|
||||||
|
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_H_
|
||||||
|
|
||||||
|
#include <atomic>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "ir/tensor.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
|
||||||
|
class ParamValue {
|
||||||
|
public:
|
||||||
|
ParamValue() {}
|
||||||
|
|
||||||
|
ParamValue(const ParamValue &other) = default;
|
||||||
|
|
||||||
|
~ParamValue() = default;
|
||||||
|
|
||||||
|
tensor::MetaTensorPtr value() const { return value_; }
|
||||||
|
void set_value(const tensor::MetaTensorPtr &value) { value_ = value; }
|
||||||
|
|
||||||
|
const std::string &name() const { return name_; }
|
||||||
|
void set_name(const std::string &name) { name_ = name; }
|
||||||
|
|
||||||
|
const std::string &sparse_grad() const { return sparse_grad_; }
|
||||||
|
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
|
||||||
|
|
||||||
|
bool requires_grad() const { return requires_grad_; }
|
||||||
|
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
|
||||||
|
|
||||||
|
bool layerwise_parallel() const { return layerwise_parallel_; }
|
||||||
|
void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }
|
||||||
|
|
||||||
|
bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
|
||||||
|
void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; }
|
||||||
|
|
||||||
|
// Whether the parameter clone from other parameter.
|
||||||
|
bool cloned() const { return cloned_; }
|
||||||
|
|
||||||
|
// Whether the parameter is cloned.
|
||||||
|
bool be_cloned() const { return be_cloned_; }
|
||||||
|
|
||||||
|
// If the parameter is cloned, generate one index per clone.
|
||||||
|
const std::vector<int32_t> &be_cloned_index() const { return be_cloned_index_; }
|
||||||
|
|
||||||
|
// If the parameter clone from other parameter, it has a unique index.
|
||||||
|
int32_t cloned_index() const { return cloned_index_; }
|
||||||
|
|
||||||
|
// Make a cloned parameter and update clone info.
|
||||||
|
ParamValuePtr Clone() {
|
||||||
|
static std::atomic<int32_t> parameter_cloned_index{1};
|
||||||
|
int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
|
||||||
|
auto clone = std::make_shared<ParamValue>(*this);
|
||||||
|
clone->be_cloned_ = false;
|
||||||
|
clone->cloned_ = true;
|
||||||
|
clone->be_cloned_index_ = {};
|
||||||
|
clone->cloned_index_ = index;
|
||||||
|
this->be_cloned_ = true;
|
||||||
|
this->be_cloned_index_.push_back(index);
|
||||||
|
return clone;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
tensor::MetaTensorPtr value_;
|
||||||
|
std::string name_{"Parameter"};
|
||||||
|
std::string sparse_grad_;
|
||||||
|
bool requires_grad_{true};
|
||||||
|
bool layerwise_parallel_{false};
|
||||||
|
bool has_indexed_slices_grad_{false};
|
||||||
|
bool be_cloned_{false};
|
||||||
|
bool cloned_{false};
|
||||||
|
std::vector<int32_t> be_cloned_index_;
|
||||||
|
int32_t cloned_index_{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_H_
|
@ -0,0 +1,55 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "ir/param_value.h"
|
||||||
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind_api/api_register.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
|
||||||
|
(void)py::class_<ParamValue, ParamValuePtr>(*m, "ParamValue")
|
||||||
|
.def(py::init())
|
||||||
|
.def("clone", &ParamValue::Clone)
|
||||||
|
.def_property("data", &ParamValue::value, &ParamValue::set_value)
|
||||||
|
.def_property("name", &ParamValue::name, &ParamValue::set_name)
|
||||||
|
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
|
||||||
|
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
|
||||||
|
&ParamValue::set_layerwise_parallel)
|
||||||
|
.def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad,
|
||||||
|
&ParamValue::set_has_indexed_slices_grad)
|
||||||
|
.def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad)
|
||||||
|
.def(py::pickle(
|
||||||
|
[](const ParamValue &p) { // __getstate__
|
||||||
|
return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
|
||||||
|
p.layerwise_parallel(), p.has_indexed_slices_grad(),
|
||||||
|
p.sparse_grad());
|
||||||
|
},
|
||||||
|
[](const py::tuple &t) { // __setstate__
|
||||||
|
if (t.size() != 6) {
|
||||||
|
std::runtime_error("Invalid state for ParamValue!");
|
||||||
|
}
|
||||||
|
ParamValuePtr p = std::make_shared<ParamValue>();
|
||||||
|
p->set_value(t[0].cast<tensor::TensorPtr>());
|
||||||
|
p->set_name(t[1].cast<std::string>());
|
||||||
|
p->set_requires_grad(t[2].cast<bool>());
|
||||||
|
p->set_layerwise_parallel(t[3].cast<bool>());
|
||||||
|
p->set_has_indexed_slices_grad(t[4].cast<bool>());
|
||||||
|
p->set_sparse_grad(t[5].cast<std::string>());
|
||||||
|
return p;
|
||||||
|
}));
|
||||||
|
}));
|
||||||
|
} // namespace mindspore
|
@ -1,43 +0,0 @@
|
|||||||
/**
|
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
*
|
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
* you may not use this file except in compliance with the License.
|
|
||||||
* You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
|
|
||||||
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "ir/anf.h"
|
|
||||||
#include "pybind11/pybind11.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
|
||||||
namespace py = pybind11;
|
|
||||||
|
|
||||||
class ParamValuePy : public ParamValue {
|
|
||||||
public:
|
|
||||||
ParamValuePy() : value_(py::none()) {}
|
|
||||||
explicit ParamValuePy(const py::object &value) : value_(value) {}
|
|
||||||
~ParamValuePy() override = default;
|
|
||||||
|
|
||||||
py::object value() { return value_; }
|
|
||||||
void set_value(const py::object &obj) { value_ = obj; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
py::object value_;
|
|
||||||
};
|
|
||||||
|
|
||||||
using ParamValuePyPtr = std::shared_ptr<ParamValuePy>;
|
|
||||||
} // namespace mindspore
|
|
||||||
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue