|
|
|
@ -99,13 +99,13 @@ class Graph {
|
|
|
|
|
// Create a normal variable with non-null VarDesc.
|
|
|
|
|
ir::Node *CreateVarNode(VarDesc *var_desc) {
|
|
|
|
|
PADDLE_ENFORCE(var_desc);
|
|
|
|
|
return AddNode(new ir::Node(var_desc, node_count_++));
|
|
|
|
|
return AddNode(new ir::Node(var_desc));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create a normal runnable operator with OpDesc.
|
|
|
|
|
ir::Node *CreateOpNode(OpDesc *op_desc) {
|
|
|
|
|
PADDLE_ENFORCE(op_desc);
|
|
|
|
|
return AddNode(new ir::Node(op_desc, node_count_++));
|
|
|
|
|
return AddNode(new ir::Node(op_desc));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create a control dependency var that connects 2 operations. The
|
|
|
|
@ -115,14 +115,13 @@ class Graph {
|
|
|
|
|
// TODO(panyx0718): control var name should be really unique.
|
|
|
|
|
const std::string name = string::Sprintf(
|
|
|
|
|
"%s@%llu", ir::Node::kControlDepVarName, node_set_.size());
|
|
|
|
|
return AddNode(
|
|
|
|
|
new ir::Node(name, ir::Node::Type::kVariable, node_count_++));
|
|
|
|
|
return AddNode(new ir::Node(name, ir::Node::Type::kVariable));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// A more free style way of creating a graph node. Mostly use for test
|
|
|
|
|
// or "copy" from another node. Avoid using it if possible.
|
|
|
|
|
ir::Node *CreateEmptyNode(const std::string &name, ir::Node::Type type) {
|
|
|
|
|
return AddNode(new ir::Node(name, type, node_count_++));
|
|
|
|
|
return AddNode(new ir::Node(name, type));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Clear all node information of the graph and return the ownership of the
|
|
|
|
@ -143,9 +142,13 @@ class Graph {
|
|
|
|
|
nodes_.erase(node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE low performance, but simple and secure.
|
|
|
|
|
Node *RetriveNode(int id) {
|
|
|
|
|
auto it = id2node_.find(id);
|
|
|
|
|
if (it != id2node_.end()) return it->second;
|
|
|
|
|
for (auto &node : nodes_) {
|
|
|
|
|
if (node.second->id() == id) {
|
|
|
|
|
return node.second.get();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -155,8 +158,6 @@ class Graph {
|
|
|
|
|
PADDLE_ENFORCE(node_set_.find(node) == node_set_.end());
|
|
|
|
|
nodes_[node].reset(node);
|
|
|
|
|
node_set_.insert(node);
|
|
|
|
|
PADDLE_ENFORCE(!id2node_.count(node->id()), "duplicate id %d", node->id());
|
|
|
|
|
id2node_[node->id()] = node;
|
|
|
|
|
return node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -166,7 +167,6 @@ class Graph {
|
|
|
|
|
std::map<std::string, std::function<void(void)>> attr_dels_;
|
|
|
|
|
std::map<ir::Node *, std::unique_ptr<ir::Node>> nodes_;
|
|
|
|
|
std::unordered_set<ir::Node *> node_set_;
|
|
|
|
|
std::map<int, Node *> id2node_;
|
|
|
|
|
int node_count_{0};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|