!2798 Decouple ParamValue from Python

Merge pull request !2798 from hewei/decouple_param_value

Commit 5f468b65b6

@@ -0,0 +1,95 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_H_
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_H_

#include <atomic>
#include <memory>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/tensor.h"

namespace mindspore {

class ParamValue {
 public:
  ParamValue() {}

  ParamValue(const ParamValue &other) = default;

  ~ParamValue() = default;

  tensor::MetaTensorPtr value() const { return value_; }
  void set_value(const tensor::MetaTensorPtr &value) { value_ = value; }

  const std::string &name() const { return name_; }
  void set_name(const std::string &name) { name_ = name; }

  const std::string &sparse_grad() const { return sparse_grad_; }
  void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }

  bool requires_grad() const { return requires_grad_; }
  void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }

  bool layerwise_parallel() const { return layerwise_parallel_; }
  void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }

  bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
  void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; }

  // Whether this parameter was cloned from another parameter.
  bool cloned() const { return cloned_; }

  // Whether this parameter has been cloned by other parameters.
  bool be_cloned() const { return be_cloned_; }

  // If this parameter has been cloned, one index is recorded per clone.
  const std::vector<int32_t> &be_cloned_index() const { return be_cloned_index_; }

  // If this parameter was cloned from another parameter, it holds a unique index.
  int32_t cloned_index() const { return cloned_index_; }

  // Make a cloned parameter and update the clone info on both sides.
  ParamValuePtr Clone() {
    static std::atomic<int32_t> parameter_cloned_index{1};
    int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
    auto clone = std::make_shared<ParamValue>(*this);
    clone->be_cloned_ = false;
    clone->cloned_ = true;
    clone->be_cloned_index_ = {};
    clone->cloned_index_ = index;
    this->be_cloned_ = true;
    this->be_cloned_index_.push_back(index);
    return clone;
  }

 private:
  tensor::MetaTensorPtr value_;
  std::string name_{"Parameter"};
  std::string sparse_grad_;
  bool requires_grad_{true};
  bool layerwise_parallel_{false};
  bool has_indexed_slices_grad_{false};
  bool be_cloned_{false};
  bool cloned_{false};
  std::vector<int32_t> be_cloned_index_;
  int32_t cloned_index_{0};
};

}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_IR_PARAM_VALUE_H_
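For orientation, here is a minimal usage sketch of the ParamValue class added above. It is not part of the diff: the function name and assertions are illustrative only, and it assumes ParamValuePtr is the std::shared_ptr<ParamValue> alias that the surrounding headers already declare.

// Illustrative sketch, not part of the merge request.
#include <cassert>
#include <memory>
#include "ir/param_value.h"

void CloneSketch() {
  auto weight = std::make_shared<mindspore::ParamValue>();
  weight->set_name("conv1.weight");
  weight->set_requires_grad(true);

  // Clone() hands the copy a unique index and records the same index on the
  // source, so the clone relationship can be reconstructed from either side.
  mindspore::ParamValuePtr copy = weight->Clone();
  assert(copy->cloned());
  assert(!copy->be_cloned());
  assert(weight->be_cloned());
  assert(copy->cloned_index() == weight->be_cloned_index().back());
  assert(copy->name() == weight->name());  // all other fields are copied as-is
}

The relaxed fetch_add on the static counter is sufficient here because the index only needs to be unique; no ordering with other memory operations is required.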
@@ -0,0 +1,55 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ir/param_value.h"
#include <stdexcept>
#include "pybind11/pybind11.h"
#include "pybind_api/api_register.h"

namespace mindspore {
namespace py = pybind11;

REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
                         (void)py::class_<ParamValue, ParamValuePtr>(*m, "ParamValue")
                           .def(py::init())
                           .def("clone", &ParamValue::Clone)
                           .def_property("data", &ParamValue::value, &ParamValue::set_value)
                           .def_property("name", &ParamValue::name, &ParamValue::set_name)
                           .def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
                           .def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
                                         &ParamValue::set_layerwise_parallel)
                           .def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad,
                                         &ParamValue::set_has_indexed_slices_grad)
                           .def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad)
                           .def(py::pickle(
                             [](const ParamValue &p) {  // __getstate__
                               return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
                                                     p.layerwise_parallel(), p.has_indexed_slices_grad(),
                                                     p.sparse_grad());
                             },
                             [](const py::tuple &t) {  // __setstate__
                               if (t.size() != 6) {
                                 // Reject malformed pickle data instead of silently ignoring it.
                                 throw std::runtime_error("Invalid state for ParamValue!");
                               }
                               ParamValuePtr p = std::make_shared<ParamValue>();
                               p->set_value(t[0].cast<tensor::TensorPtr>());
                               p->set_name(t[1].cast<std::string>());
                               p->set_requires_grad(t[2].cast<bool>());
                               p->set_layerwise_parallel(t[3].cast<bool>());
                               p->set_has_indexed_slices_grad(t[4].cast<bool>());
                               p->set_sparse_grad(t[5].cast<std::string>());
                               return p;
                             }));
                       }));
}  // namespace mindspore
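The py::pickle pair above is the standard pybind11 __getstate__/__setstate__ pattern: pack the fields into a tuple, then validate and rebuild on the way back. As a self-contained illustration of the same pattern, here is a sketch on a hypothetical Config struct (module name, struct, and fields are invented for the example); the size check must throw, otherwise a malformed pickle would be silently accepted.

// Standalone sketch of the py::pickle pattern; build as a normal pybind11 module.
#include <memory>
#include <stdexcept>
#include <string>
#include "pybind11/pybind11.h"

namespace py = pybind11;

struct Config {
  std::string name;
  bool enabled{false};
};

PYBIND11_MODULE(pickle_sketch, m) {
  py::class_<Config, std::shared_ptr<Config>>(m, "Config")
    .def(py::init<>())
    .def_readwrite("name", &Config::name)
    .def_readwrite("enabled", &Config::enabled)
    .def(py::pickle(
      // __getstate__: pack the fields into a tuple, mirroring ParamValue above.
      [](const Config &c) { return py::make_tuple(c.name, c.enabled); },
      // __setstate__: validate the tuple size and rebuild the object.
      [](const py::tuple &t) {
        if (t.size() != 2) {
          throw std::runtime_error("Invalid state for Config!");
        }
        auto c = std::make_shared<Config>();
        c->name = t[0].cast<std::string>();
        c->enabled = t[1].cast<bool>();
        return c;
      }));
}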
@@ -1,43 +0,0 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_

#include <memory>

#include "ir/anf.h"
#include "pybind11/pybind11.h"

namespace mindspore {
namespace py = pybind11;

class ParamValuePy : public ParamValue {
 public:
  ParamValuePy() : value_(py::none()) {}
  explicit ParamValuePy(const py::object &value) : value_(value) {}
  ~ParamValuePy() override = default;

  py::object value() { return value_; }
  void set_value(const py::object &obj) { value_ = obj; }

 private:
  py::object value_;
};

using ParamValuePyPtr = std::shared_ptr<ParamValuePy>;
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
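The deleted header above is the Python-coupled half that this merge request removes: ParamValuePy kept the parameter's default data as a py::object, so any C++ code reading it had to pull in pybind11. Below is a hypothetical sketch (not taken from the MindSpore sources) of the kind of pure-C++ consumer the new ParamValue allows; the function name is invented for illustration.

// Hypothetical consumer: reads parameter metadata without pybind11 or the GIL.
#include <string>
#include "ir/param_value.h"

std::string DescribeParam(const mindspore::ParamValue &param) {
  // value() is a plain tensor::MetaTensorPtr now, so a null check suffices
  // where the old class defaulted to py::none().
  const std::string state = (param.value() != nullptr) ? "has data" : "no data";
  return param.name() + ": " + state + (param.requires_grad() ? ", requires grad" : "");
}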
Some files were not shown because too many files have changed in this diff.