parent
4f01de6378
commit
e537634d16
@ -1,150 +0,0 @@
|
|||||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "paddle/fluid/framework/details/graph_print_pass.h"
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include "paddle/fluid/framework/ir/graph_helper.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace framework {
|
|
||||||
namespace details {
|
|
||||||
|
|
||||||
class GraphvizVar : public GraphvizNode {
|
|
||||||
public:
|
|
||||||
GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
|
|
||||||
friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) {
|
|
||||||
sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]"
|
|
||||||
<< std::endl;
|
|
||||||
return sout;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class GraphvizOp : public GraphvizNode {
|
|
||||||
public:
|
|
||||||
GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
|
|
||||||
friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) {
|
|
||||||
sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name()
|
|
||||||
<< "\", shape=rect]" << std::endl;
|
|
||||||
sout << op.stream_.str();
|
|
||||||
return sout;
|
|
||||||
}
|
|
||||||
template <typename Callback>
|
|
||||||
void AddEdge(const Callback& cb) {
|
|
||||||
std::string op_name = "op_" + std::to_string(id_);
|
|
||||||
for (auto var : node_->inputs) {
|
|
||||||
std::string var_name = "var_" + std::to_string(cb(var));
|
|
||||||
stream_ << var_name << "->" << op_name << std::endl;
|
|
||||||
}
|
|
||||||
for (auto var : node_->outputs) {
|
|
||||||
std::string var_name = "var_" + std::to_string(cb(var));
|
|
||||||
stream_ << op_name << "->" << var_name << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Callback>
|
|
||||||
void AddCustomEdge(const Callback& cb) {
|
|
||||||
stream_ << cb() << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::ostringstream stream_;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T, typename Container>
|
|
||||||
std::vector<T*> FilterByNodeWrapper(const Container& con) {
|
|
||||||
std::vector<T*> ret;
|
|
||||||
for (auto& node : con) {
|
|
||||||
auto i = dynamic_cast<T*>(node.get());
|
|
||||||
if (i != nullptr) ret.emplace_back(i);
|
|
||||||
}
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
|
|
||||||
const ir::Graph& graph) const {
|
|
||||||
// Convert to GraphvizNode format
|
|
||||||
auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
|
|
||||||
graphviz_nodes.clear();
|
|
||||||
std::unordered_map<ir::Node*, int> vars;
|
|
||||||
std::unordered_map<ir::Node*, GraphvizOp*> ops;
|
|
||||||
int var_id = 0;
|
|
||||||
int op_id = 0;
|
|
||||||
for (auto& node : graph.Nodes()) {
|
|
||||||
if (node->IsVar()) {
|
|
||||||
graphviz_nodes.emplace(new GraphvizVar(node, var_id));
|
|
||||||
vars.emplace(std::make_pair(node, var_id++));
|
|
||||||
} else if (node->IsOp()) {
|
|
||||||
std::unique_ptr<GraphvizOp> op(new GraphvizOp(node, op_id++));
|
|
||||||
ops[node] = op.get();
|
|
||||||
graphviz_nodes.emplace(std::move(op));
|
|
||||||
} else {
|
|
||||||
PADDLE_THROW("Unknown op type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect circle. Draw circle in different lines
|
|
||||||
std::vector<std::vector<ir::Node*>> circles;
|
|
||||||
const std::string kCircleEdge = "[color=red,penwidth=3.0]";
|
|
||||||
if (ir::FindCircleSubGraph(graph, &circles)) {
|
|
||||||
VLOG(3) << "Graph has circle! circles count : " << circles.size();
|
|
||||||
for (auto& circle : circles) {
|
|
||||||
for (size_t i = 0; i < circle.size() - 1; ++i) {
|
|
||||||
GraphvizOp* prev = ops[circle[i]];
|
|
||||||
GraphvizOp* next = ops[circle[i + 1]];
|
|
||||||
std::string prev_op = "op_" + std::to_string(prev->Id());
|
|
||||||
std::string next_op = "op_" + std::to_string(next->Id());
|
|
||||||
prev->AddCustomEdge([&]() -> std::string {
|
|
||||||
return prev_op + "->" + next_op + kCircleEdge;
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return vars;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SSAGraphPrinterImpl::Print(const ir::Graph& graph,
|
|
||||||
std::ostream& sout) const {
|
|
||||||
auto vars = ToGraphvizNode(graph);
|
|
||||||
auto& nodes = graph.Get<GraphvizNodes>(kGraphviz);
|
|
||||||
|
|
||||||
sout << "digraph G {\n";
|
|
||||||
for (auto& var : FilterByNodeWrapper<GraphvizVar>(nodes)) {
|
|
||||||
sout << *var;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (auto& op : FilterByNodeWrapper<GraphvizOp>(nodes)) {
|
|
||||||
op->AddEdge([&vars](ir::Node* var) { return vars.at(var); });
|
|
||||||
sout << *op;
|
|
||||||
}
|
|
||||||
sout << "}\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<ir::Graph> SSAGraphPrintPass::ApplyImpl(
|
|
||||||
std::unique_ptr<ir::Graph> graph) const {
|
|
||||||
printer_.reset(new SSAGraphPrinterImpl());
|
|
||||||
std::unique_ptr<std::ostream> fout(
|
|
||||||
new std::ofstream(Get<std::string>(kGraphvizPath)));
|
|
||||||
PADDLE_ENFORCE(fout->good() == true, "Failed to open file.");
|
|
||||||
|
|
||||||
printer_->Print(*graph, *fout);
|
|
||||||
return graph;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace details
|
|
||||||
} // namespace framework
|
|
||||||
} // namespace paddle
|
|
||||||
|
|
||||||
REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass)
|
|
||||||
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
|
|
@ -1,73 +0,0 @@
|
|||||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
#include <fstream>
|
|
||||||
#include <memory>
|
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace framework {
|
|
||||||
namespace details {
|
|
||||||
|
|
||||||
constexpr char kGraphvizPath[] = "debug_graphviz_path";
|
|
||||||
constexpr char kGraphviz[] = "graphviz";
|
|
||||||
|
|
||||||
// NOTE(dzhwinter): If the graph contains circles.
|
|
||||||
// the graph can not be topology sort.
|
|
||||||
// This printer will print the whole graph
|
|
||||||
// and highlight the circles. It's quite useful
|
|
||||||
// for debug the deadlock and circles.
|
|
||||||
class GraphvizNode {
|
|
||||||
public:
|
|
||||||
GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
|
|
||||||
virtual ~GraphvizNode() = default;
|
|
||||||
|
|
||||||
int Id() const { return id_; }
|
|
||||||
|
|
||||||
protected:
|
|
||||||
ir::Node* node_;
|
|
||||||
int id_;
|
|
||||||
};
|
|
||||||
|
|
||||||
typedef std::unordered_set<std::unique_ptr<GraphvizNode>> GraphvizNodes;
|
|
||||||
|
|
||||||
class SSAGraphPrinter {
|
|
||||||
public:
|
|
||||||
virtual ~SSAGraphPrinter() {}
|
|
||||||
virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
class SSAGraphPrinterImpl : public SSAGraphPrinter {
|
|
||||||
public:
|
|
||||||
void Print(const ir::Graph& graph, std::ostream& sout) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::unordered_map<ir::Node*, int> ToGraphvizNode(
|
|
||||||
const ir::Graph& graph) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
class SSAGraphPrintPass : public ir::Pass {
|
|
||||||
protected:
|
|
||||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
|
||||||
std::unique_ptr<ir::Graph> graph) const override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
mutable std::unique_ptr<SSAGraphPrinter> printer_;
|
|
||||||
};
|
|
||||||
} // namespace details
|
|
||||||
} // namespace framework
|
|
||||||
} // namespace paddle
|
|
@ -1,190 +0,0 @@
|
|||||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "paddle/fluid/framework/details/graph_print_pass.h"
|
|
||||||
#include "paddle/fluid/framework/details/graph_test_base.h"
|
|
||||||
|
|
||||||
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
|
|
||||||
paddle::framework::SumOpMaker);
|
|
||||||
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
|
|
||||||
paddle::framework::SplitOpMaker);
|
|
||||||
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
|
|
||||||
paddle::framework::AssignOpMaker,
|
|
||||||
paddle::framework::DummyVarTypeInference);
|
|
||||||
|
|
||||||
/*
|
|
||||||
a @ b
|
|
||||||
c
|
|
||||||
d @ e
|
|
||||||
*/
|
|
||||||
|
|
||||||
using paddle::framework::ProgramDesc;
|
|
||||||
using paddle::framework::proto::VarType;
|
|
||||||
|
|
||||||
inline static ProgramDesc FillProgramDesc() {
|
|
||||||
ProgramDesc prog;
|
|
||||||
prog.MutableBlock(0)->Var("a")->SetType(VarType::LOD_TENSOR);
|
|
||||||
prog.MutableBlock(0)->Var("b")->SetType(VarType::LOD_TENSOR);
|
|
||||||
prog.MutableBlock(0)->Var("c")->SetType(VarType::LOD_TENSOR);
|
|
||||||
prog.MutableBlock(0)->Var("d")->SetType(VarType::LOD_TENSOR);
|
|
||||||
prog.MutableBlock(0)->Var("e")->SetType(VarType::LOD_TENSOR);
|
|
||||||
{
|
|
||||||
auto* op = prog.MutableBlock(0)->AppendOp();
|
|
||||||
op->SetType("sum");
|
|
||||||
op->SetInput("X", {"a", "b"});
|
|
||||||
op->SetOutput("Out", {"c"});
|
|
||||||
}
|
|
||||||
{
|
|
||||||
auto* op = prog.MutableBlock(0)->AppendOp();
|
|
||||||
op->SetType("split");
|
|
||||||
op->SetInput("X", {"c"});
|
|
||||||
op->SetOutput("Out", {"d", "e"});
|
|
||||||
}
|
|
||||||
{
|
|
||||||
auto* op = prog.MutableBlock(0)->AppendOp();
|
|
||||||
op->SetType("sum");
|
|
||||||
op->SetInput("X", {"d", "e"});
|
|
||||||
op->SetOutput("Out", {"d"});
|
|
||||||
}
|
|
||||||
{
|
|
||||||
auto* op = prog.MutableBlock(0)->AppendOp();
|
|
||||||
op->SetType("assign");
|
|
||||||
op->SetInput("X", {"d"});
|
|
||||||
op->SetOutput("Out", {"d"});
|
|
||||||
}
|
|
||||||
return prog;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace framework {
|
|
||||||
namespace details {
|
|
||||||
|
|
||||||
TEST(SSAGraphPrinter, Normal) {
|
|
||||||
auto program = FillProgramDesc();
|
|
||||||
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
|
|
||||||
graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
|
|
||||||
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
|
|
||||||
|
|
||||||
// redirect debug graph to a file.
|
|
||||||
constexpr char graph_path[] = "graph_print_pass.txt";
|
|
||||||
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
|
|
||||||
PADDLE_ENFORCE(fout->good());
|
|
||||||
printer->Print(*graph, *fout);
|
|
||||||
}
|
|
||||||
|
|
||||||
using ir::Graph;
|
|
||||||
using ir::Node;
|
|
||||||
void BuildCircleGraph(Graph* g) {
|
|
||||||
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
|
|
||||||
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
|
|
||||||
|
|
||||||
o1->outputs.push_back(v1);
|
|
||||||
o1->inputs.push_back(v1);
|
|
||||||
v1->inputs.push_back(o1);
|
|
||||||
v1->outputs.push_back(o1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BuildCircleGraph2(Graph* g) {
|
|
||||||
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
|
|
||||||
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
|
|
||||||
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
|
|
||||||
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
|
|
||||||
|
|
||||||
o1->outputs.push_back(v1);
|
|
||||||
o2->inputs.push_back(v1);
|
|
||||||
v1->inputs.push_back(o1);
|
|
||||||
v1->outputs.push_back(o2);
|
|
||||||
|
|
||||||
o2->outputs.push_back(v2);
|
|
||||||
o1->inputs.push_back(v2);
|
|
||||||
v2->inputs.push_back(o2);
|
|
||||||
v2->outputs.push_back(o1);
|
|
||||||
}
|
|
||||||
|
|
||||||
void BuildNoCircleGraph(Graph* g) {
|
|
||||||
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
|
|
||||||
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
|
|
||||||
ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
|
|
||||||
ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
|
|
||||||
ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
|
|
||||||
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
|
|
||||||
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
|
|
||||||
ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
|
|
||||||
ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
|
|
||||||
|
|
||||||
// o1->v1->o2
|
|
||||||
o1->outputs.push_back(v1);
|
|
||||||
o2->inputs.push_back(v1);
|
|
||||||
v1->inputs.push_back(o1);
|
|
||||||
v1->outputs.push_back(o2);
|
|
||||||
// o2->v2->o3
|
|
||||||
// o2->v2->o4
|
|
||||||
o2->outputs.push_back(v2);
|
|
||||||
o3->inputs.push_back(v2);
|
|
||||||
o4->inputs.push_back(v2);
|
|
||||||
v2->inputs.push_back(o2);
|
|
||||||
v2->outputs.push_back(o3);
|
|
||||||
v2->outputs.push_back(o4);
|
|
||||||
// o2->v3->o5
|
|
||||||
o2->outputs.push_back(v3);
|
|
||||||
o5->inputs.push_back(v3);
|
|
||||||
v3->inputs.push_back(o2);
|
|
||||||
v3->outputs.push_back(o5);
|
|
||||||
// o3-v4->o5
|
|
||||||
o3->outputs.push_back(v4);
|
|
||||||
o5->inputs.push_back(v4);
|
|
||||||
v4->inputs.push_back(o3);
|
|
||||||
v4->outputs.push_back(o5);
|
|
||||||
|
|
||||||
// o2->v3->o1
|
|
||||||
v3->outputs.push_back(o1);
|
|
||||||
o1->inputs.push_back(v3);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(SSAGraphPrinter, SimpleCircle) {
|
|
||||||
ProgramDesc prog;
|
|
||||||
|
|
||||||
Graph graph(prog);
|
|
||||||
BuildCircleGraph(&graph);
|
|
||||||
ASSERT_TRUE(HasCircle(graph));
|
|
||||||
|
|
||||||
graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
|
|
||||||
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
|
|
||||||
|
|
||||||
// redirect debug graph to a file.
|
|
||||||
constexpr char graph_path[] = "graph_print_pass_simple_circle.txt";
|
|
||||||
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
|
|
||||||
PADDLE_ENFORCE(fout->good());
|
|
||||||
printer->Print(graph, *fout);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(SSAGraphPrinter, ComplexCircle) {
|
|
||||||
ProgramDesc prog;
|
|
||||||
Graph graph(prog);
|
|
||||||
BuildCircleGraph2(&graph);
|
|
||||||
ASSERT_TRUE(HasCircle(graph));
|
|
||||||
|
|
||||||
graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
|
|
||||||
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
|
|
||||||
|
|
||||||
// redirect debug graph to a file.
|
|
||||||
constexpr char graph_path[] = "graph_print_pass_complex_circle.txt";
|
|
||||||
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
|
|
||||||
PADDLE_ENFORCE(fout->good());
|
|
||||||
printer->Print(graph, *fout);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace details
|
|
||||||
} // namespace framework
|
|
||||||
} // namespace paddle
|
|
Loading…
Reference in new issue