|
|
|
@ -40,6 +40,8 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
|
|
|
|
|
PADDLE_ENFORCE(graph);
|
|
|
|
|
PADDLE_ENFORCE(desc_);
|
|
|
|
|
// insert vars
|
|
|
|
|
// The `var2id` keeps a map from a variable's name to its Node-id, the Node-id
|
|
|
|
|
// will keep updating to its latest alias during the graph-building.
|
|
|
|
|
std::unordered_map<std::string, size_t> var2id;
|
|
|
|
|
auto &main_block = desc_->blocks(framework::kRootBlockIndex);
|
|
|
|
|
for (int i = 0; i < main_block.vars_size(); i++) {
|
|
|
|
@ -51,6 +53,15 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
|
|
|
|
|
var2id[var.name()] = v->id();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// The variables in a SSA can only write once, so if a variable is written
|
|
|
|
|
// multiple times(quite common in our ProgramDesc design), multiple alias
|
|
|
|
|
// Nodes of this variable will be created, and each will just write once.
|
|
|
|
|
|
|
|
|
|
// An set that keep all the names of the variables(the original, not alias)
|
|
|
|
|
// that have been written(as outputs). Once an Op's output variable hit the
|
|
|
|
|
// set, it should create a new alias and update the global alias for this
|
|
|
|
|
// variable. And that make a Data Flow Graph a SSA.
|
|
|
|
|
std::unordered_set<Node *> unique_written_vars;
|
|
|
|
|
for (int i = 0; i < main_block.ops_size(); i++) {
|
|
|
|
|
const auto &op = main_block.ops(i);
|
|
|
|
|
auto *o = graph->nodes.Create(Node::Type::kFunction);
|
|
|
|
@ -62,33 +73,33 @@ void FluidToDataFlowGraphPass::Run(DataFlowGraph *graph) {
|
|
|
|
|
o->SetPbMsg(op.SerializeAsString());
|
|
|
|
|
|
|
|
|
|
// set inputs and outputs
|
|
|
|
|
std::unordered_set<Node *> inlinks;
|
|
|
|
|
for (int j = 0; j < op.inputs_size(); j++) {
|
|
|
|
|
auto &in_var = op.inputs(j);
|
|
|
|
|
for (int k = 0; k < in_var.arguments_size(); k++) {
|
|
|
|
|
auto *in = graph->nodes.GetMutable(var2id.at(in_var.arguments(k)));
|
|
|
|
|
in->outlinks.push_back(o);
|
|
|
|
|
o->inlinks.push_back(in);
|
|
|
|
|
inlinks.insert(in);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (int j = 0; j < op.outputs_size(); j++) {
|
|
|
|
|
auto &out_var = op.outputs(j);
|
|
|
|
|
for (int k = 0; k < out_var.arguments_size(); k++) {
|
|
|
|
|
auto *out = graph->nodes.GetMutable(var2id[out_var.arguments(k)]);
|
|
|
|
|
if (inlinks.count(out)) {
|
|
|
|
|
if (unique_written_vars.count(out)) {
|
|
|
|
|
// Loop found, for example, a = op(a), use SSA, change to a1 = op(a).
|
|
|
|
|
auto *out_alias = graph->nodes.Create(Node::Type::kValue);
|
|
|
|
|
out_alias->SetName(out->name());
|
|
|
|
|
out_alias->SetPbDesc(out->pb_desc());
|
|
|
|
|
out_alias->SetPbMsg(out->pb_msg());
|
|
|
|
|
var2id[out_alias->name()] = out_alias->id(); // update a -> a0
|
|
|
|
|
var2id[out_alias->name()] =
|
|
|
|
|
out_alias->id(); // update variable's alias Node
|
|
|
|
|
LOG(INFO) << "loop found in graph, create SSA alias node ["
|
|
|
|
|
<< out_alias->repr() << "] for [" << out->repr() << "]";
|
|
|
|
|
out = out_alias;
|
|
|
|
|
}
|
|
|
|
|
out->inlinks.push_back(o);
|
|
|
|
|
o->outlinks.push_back(out);
|
|
|
|
|
unique_written_vars.insert(out);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|