|
|
|
@ -15,6 +15,7 @@ limitations under the License. */
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <cstdint>
|
|
|
|
|
#include <functional>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
@ -23,13 +24,23 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
|
|
|
|
|
class Node {
|
|
|
|
|
public:
|
|
|
|
|
enum class Type { kNone = -1, kOperation, kVariable };
|
|
|
|
|
|
|
|
|
|
Node() {}
|
|
|
|
|
virtual ~Node() {}
|
|
|
|
|
Node(const std::string& name, Type type) : name_(name), type_(type) {}
|
|
|
|
|
|
|
|
|
|
virtual ~Node() {
|
|
|
|
|
for (auto& attr : attrs_) {
|
|
|
|
|
if (attr_dels_.find(attr.first) != attr_dels_.end()) {
|
|
|
|
|
attr_dels_[attr.first]();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
attr_dels_.clear();
|
|
|
|
|
attrs_.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int64_t ID() const { return id_; }
|
|
|
|
|
|
|
|
|
@ -43,17 +54,42 @@ class Node {
|
|
|
|
|
|
|
|
|
|
Type NodeType() const { return type_; }
|
|
|
|
|
|
|
|
|
|
std::vector<Node *> inputs;
|
|
|
|
|
std::vector<Node *> outputs;
|
|
|
|
|
template <typename AttrType>
|
|
|
|
|
void Set(const std::string& name, AttrType attr) {
|
|
|
|
|
attrs_[name] = attr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename AttrType>
|
|
|
|
|
void Set(const std::string& name, AttrType* attr,
|
|
|
|
|
std::function<void(void)> attr_del) {
|
|
|
|
|
attrs_[name] = attr;
|
|
|
|
|
attr_dels_[name] = attr_del;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<Node*> inputs;
|
|
|
|
|
std::vector<Node*> outputs;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::map<std::string, std::vector<boost::any>> attrs_;
|
|
|
|
|
std::map<std::string, boost::any> attrs_;
|
|
|
|
|
std::map<std::string, std::function<void(void)>> attr_dels_;
|
|
|
|
|
int64_t id_ = 0;
|
|
|
|
|
std::string name_;
|
|
|
|
|
Type type_;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
DISABLE_COPY_AND_ASSIGN(Node);
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Variable : public Node {
|
|
|
|
|
public:
|
|
|
|
|
explicit Variable(const std::string& name) : Node(name, Type::kVariable) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Operation : public Node {
|
|
|
|
|
public:
|
|
|
|
|
explicit Operation(const std::string& name) : Node(name, Type::kOperation) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace ir
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|