parent
879a519136
commit
e6f82af849
@ -0,0 +1,50 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind_api/ir/cell_py.h"
|
||||
#include <string>
|
||||
|
||||
#include "pybind_api/api_register.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "pipeline/jit/parse/python_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &obj) {
|
||||
std::string attr_name = name;
|
||||
ValuePtr converted_ret = nullptr;
|
||||
if (py::isinstance<py::module>(obj)) {
|
||||
MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module";
|
||||
}
|
||||
bool converted = parse::ConvertData(obj, &converted_ret, true);
|
||||
if (!converted) {
|
||||
MS_LOG(DEBUG) << "Attribute convert error with type: " << std::string(py::str(obj));
|
||||
} else {
|
||||
MS_LOG(DEBUG) << cell->ToString() << " add attr " << attr_name << converted_ret->ToString();
|
||||
cell->AddAttr(attr_name, converted_ret);
|
||||
}
|
||||
}
|
||||
// Define python 'Cell' class.
|
||||
REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) {
|
||||
(void)py::class_<Cell, std::shared_ptr<Cell>>(*m, "Cell_")
|
||||
.def(py::init<std::string &>())
|
||||
.def("__str__", &Cell::ToString)
|
||||
.def("_add_attr", &CellPy::AddAttr, "Add Cell attr.")
|
||||
.def("_del_attr", &Cell::DelAttr, "Delete Cell attr.")
|
||||
.def(
|
||||
"construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; },
|
||||
"construct");
|
||||
}));
|
||||
} // namespace mindspore
|
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_UTILS_CELL_PY_H_
|
||||
#define MINDSPORE_CCSRC_UTILS_CELL_PY_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/numpy.h"
|
||||
|
||||
#include "ir/cell.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
// brief mindspore namespace.
|
||||
//
|
||||
// mindspore namespace is the top level namespace of Mindsporeession project.
|
||||
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
|
||||
namespace mindspore {
|
||||
|
||||
// Cell python wrapper and adapter class.
|
||||
class CellPy {
|
||||
public:
|
||||
static void AddAttr(CellPtr cell, const std::string &name, const py::object &obj);
|
||||
};
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_UTILS_CELL_PY_H_
|
@ -0,0 +1,94 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ir/cell.h"
|
||||
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
|
||||
namespace mindspore {
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
|
||||
abstract::AbstractBasePtr Cell::ToAbstract() {
|
||||
/*
|
||||
std::vector<abstract::AbstractAttribute> abs_attrs;
|
||||
std::transform(attrs_.begin(), attrs_.end(), std::back_inserter(abs_attrs),
|
||||
[](std::pair<std::string, ValuePtr> attr) -> abstract::AbstractAttribute {
|
||||
return std::make_pair(attr.first, attr.second->ToAbstract());
|
||||
});
|
||||
auto abs = std::make_shared<abstract::AbstractCell>(shared_from_base<Named>(), abs_attrs);
|
||||
abs->set_value(shared_from_base<Value>());
|
||||
return abs;
|
||||
*/
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool Cell::operator==(const Value &other) const {
|
||||
if (other.isa<Cell>()) {
|
||||
auto other_prim = static_cast<const Cell &>(other);
|
||||
return *this == other_prim;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool Cell::operator==(const Cell &other) const {
|
||||
if (name() != other.name()) {
|
||||
return false;
|
||||
}
|
||||
if (attrs_.size() != other.attrs_.size()) {
|
||||
return false;
|
||||
}
|
||||
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
|
||||
if (item.second == nullptr) {
|
||||
return false;
|
||||
}
|
||||
auto iter = other.attrs_.find(item.first);
|
||||
if (iter == other.attrs_.end()) {
|
||||
return false;
|
||||
}
|
||||
return *item.second == *iter->second;
|
||||
});
|
||||
return all;
|
||||
}
|
||||
|
||||
std::string Cell::GetAttrString() const {
|
||||
std::ostringstream buffer;
|
||||
bool begin = true;
|
||||
buffer << "{" << std::endl;
|
||||
for (auto &attr : attrs_) {
|
||||
if (!begin) {
|
||||
buffer << ", " << std::endl;
|
||||
} else {
|
||||
begin = false;
|
||||
}
|
||||
buffer << attr.first << ":" << attr.second->ToString();
|
||||
}
|
||||
buffer << "}";
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
std::string Cell::ToString() const {
|
||||
std::ostringstream buffer;
|
||||
buffer << "Cell " << name();
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
void Cell::DelAttr(const std::string &name) { attrs_.erase(name); }
|
||||
} // namespace mindspore
|
@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_IR_CELL_H_
|
||||
#define MINDSPORE_CCSRC_IR_CELL_H_
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/misc.h"
|
||||
|
||||
namespace mindspore {
|
||||
using abstract::AbstractBasePtr;
|
||||
using abstract::AbstractBasePtrList;
|
||||
// value for Cell
|
||||
|
||||
class Cell : public Named {
|
||||
public:
|
||||
explicit Cell(const std::string &name) : Named(name) {}
|
||||
MS_DECLARE_PARENT(Cell, Named);
|
||||
abstract::AbstractBasePtr ToAbstract() override;
|
||||
std::string ToString() const override;
|
||||
std::string GetAttrString() const;
|
||||
|
||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs_input) { attrs_ = attrs_input; }
|
||||
|
||||
void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
|
||||
void DelAttr(const std::string &name);
|
||||
ValuePtr GetAttr(const std::string &attr_name) const {
|
||||
auto iter = attrs_.find(attr_name);
|
||||
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||
}
|
||||
|
||||
bool HasAttr(const std::string &attr_name) const {
|
||||
auto iter = attrs_.find(attr_name);
|
||||
return !(iter == attrs_.cend());
|
||||
}
|
||||
|
||||
bool operator==(const Value &other) const override;
|
||||
bool operator==(const Cell &other) const;
|
||||
~Cell() override = default;
|
||||
const bool parse_info_ = true;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, ValuePtr> attrs_;
|
||||
};
|
||||
|
||||
using CellPtr = std::shared_ptr<Cell>;
|
||||
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_IR_CELL_H_
|
Loading…
Reference in new issue