parent
4f01de6378
commit
e537634d16
@ -1,150 +0,0 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/graph_print_pass.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/ir/graph_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class GraphvizVar : public GraphvizNode {
|
||||
public:
|
||||
GraphvizVar(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
|
||||
friend std::ostream& operator<<(std::ostream& sout, const GraphvizVar& var) {
|
||||
sout << "var_" << var.id_ << " [label=\"" << var.node_->Name() << "\"]"
|
||||
<< std::endl;
|
||||
return sout;
|
||||
}
|
||||
};
|
||||
|
||||
class GraphvizOp : public GraphvizNode {
|
||||
public:
|
||||
GraphvizOp(ir::Node* n, const int& i) : GraphvizNode(n, i) {}
|
||||
friend std::ostream& operator<<(std::ostream& sout, const GraphvizOp& op) {
|
||||
sout << "op_" + std::to_string(op.id_) << " [label=\"" << op.node_->Name()
|
||||
<< "\", shape=rect]" << std::endl;
|
||||
sout << op.stream_.str();
|
||||
return sout;
|
||||
}
|
||||
template <typename Callback>
|
||||
void AddEdge(const Callback& cb) {
|
||||
std::string op_name = "op_" + std::to_string(id_);
|
||||
for (auto var : node_->inputs) {
|
||||
std::string var_name = "var_" + std::to_string(cb(var));
|
||||
stream_ << var_name << "->" << op_name << std::endl;
|
||||
}
|
||||
for (auto var : node_->outputs) {
|
||||
std::string var_name = "var_" + std::to_string(cb(var));
|
||||
stream_ << op_name << "->" << var_name << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Callback>
|
||||
void AddCustomEdge(const Callback& cb) {
|
||||
stream_ << cb() << std::endl;
|
||||
}
|
||||
|
||||
private:
|
||||
std::ostringstream stream_;
|
||||
};
|
||||
|
||||
template <typename T, typename Container>
|
||||
std::vector<T*> FilterByNodeWrapper(const Container& con) {
|
||||
std::vector<T*> ret;
|
||||
for (auto& node : con) {
|
||||
auto i = dynamic_cast<T*>(node.get());
|
||||
if (i != nullptr) ret.emplace_back(i);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::unordered_map<ir::Node*, int> SSAGraphPrinterImpl::ToGraphvizNode(
|
||||
const ir::Graph& graph) const {
|
||||
// Convert to GraphvizNode format
|
||||
auto& graphviz_nodes = graph.Get<GraphvizNodes>(kGraphviz);
|
||||
graphviz_nodes.clear();
|
||||
std::unordered_map<ir::Node*, int> vars;
|
||||
std::unordered_map<ir::Node*, GraphvizOp*> ops;
|
||||
int var_id = 0;
|
||||
int op_id = 0;
|
||||
for (auto& node : graph.Nodes()) {
|
||||
if (node->IsVar()) {
|
||||
graphviz_nodes.emplace(new GraphvizVar(node, var_id));
|
||||
vars.emplace(std::make_pair(node, var_id++));
|
||||
} else if (node->IsOp()) {
|
||||
std::unique_ptr<GraphvizOp> op(new GraphvizOp(node, op_id++));
|
||||
ops[node] = op.get();
|
||||
graphviz_nodes.emplace(std::move(op));
|
||||
} else {
|
||||
PADDLE_THROW("Unknown op type");
|
||||
}
|
||||
}
|
||||
|
||||
// Detect circle. Draw circle in different lines
|
||||
std::vector<std::vector<ir::Node*>> circles;
|
||||
const std::string kCircleEdge = "[color=red,penwidth=3.0]";
|
||||
if (ir::FindCircleSubGraph(graph, &circles)) {
|
||||
VLOG(3) << "Graph has circle! circles count : " << circles.size();
|
||||
for (auto& circle : circles) {
|
||||
for (size_t i = 0; i < circle.size() - 1; ++i) {
|
||||
GraphvizOp* prev = ops[circle[i]];
|
||||
GraphvizOp* next = ops[circle[i + 1]];
|
||||
std::string prev_op = "op_" + std::to_string(prev->Id());
|
||||
std::string next_op = "op_" + std::to_string(next->Id());
|
||||
prev->AddCustomEdge([&]() -> std::string {
|
||||
return prev_op + "->" + next_op + kCircleEdge;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
return vars;
|
||||
}
|
||||
|
||||
void SSAGraphPrinterImpl::Print(const ir::Graph& graph,
|
||||
std::ostream& sout) const {
|
||||
auto vars = ToGraphvizNode(graph);
|
||||
auto& nodes = graph.Get<GraphvizNodes>(kGraphviz);
|
||||
|
||||
sout << "digraph G {\n";
|
||||
for (auto& var : FilterByNodeWrapper<GraphvizVar>(nodes)) {
|
||||
sout << *var;
|
||||
}
|
||||
|
||||
for (auto& op : FilterByNodeWrapper<GraphvizOp>(nodes)) {
|
||||
op->AddEdge([&vars](ir::Node* var) { return vars.at(var); });
|
||||
sout << *op;
|
||||
}
|
||||
sout << "}\n";
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Graph> SSAGraphPrintPass::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
printer_.reset(new SSAGraphPrinterImpl());
|
||||
std::unique_ptr<std::ostream> fout(
|
||||
new std::ofstream(Get<std::string>(kGraphvizPath)));
|
||||
PADDLE_ENFORCE(fout->good() == true, "Failed to open file.");
|
||||
|
||||
printer_->Print(*graph, *fout);
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(graph_print_pass, paddle::framework::details::SSAGraphPrintPass)
|
||||
.RequirePassAttr(paddle::framework::details::kGraphvizPath);
|
@ -1,73 +0,0 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "paddle/fluid/framework/details/multi_devices_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
constexpr char kGraphvizPath[] = "debug_graphviz_path";
|
||||
constexpr char kGraphviz[] = "graphviz";
|
||||
|
||||
// NOTE(dzhwinter): If the graph contains circles.
|
||||
// the graph can not be topology sort.
|
||||
// This printer will print the whole graph
|
||||
// and highlight the circles. It's quite useful
|
||||
// for debug the deadlock and circles.
|
||||
class GraphvizNode {
|
||||
public:
|
||||
GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
|
||||
virtual ~GraphvizNode() = default;
|
||||
|
||||
int Id() const { return id_; }
|
||||
|
||||
protected:
|
||||
ir::Node* node_;
|
||||
int id_;
|
||||
};
|
||||
|
||||
typedef std::unordered_set<std::unique_ptr<GraphvizNode>> GraphvizNodes;
|
||||
|
||||
class SSAGraphPrinter {
|
||||
public:
|
||||
virtual ~SSAGraphPrinter() {}
|
||||
virtual void Print(const ir::Graph& graph, std::ostream& sout) const = 0;
|
||||
};
|
||||
|
||||
class SSAGraphPrinterImpl : public SSAGraphPrinter {
|
||||
public:
|
||||
void Print(const ir::Graph& graph, std::ostream& sout) const override;
|
||||
|
||||
private:
|
||||
std::unordered_map<ir::Node*, int> ToGraphvizNode(
|
||||
const ir::Graph& graph) const;
|
||||
};
|
||||
|
||||
class SSAGraphPrintPass : public ir::Pass {
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
|
||||
private:
|
||||
mutable std::unique_ptr<SSAGraphPrinter> printer_;
|
||||
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -1,190 +0,0 @@
|
||||
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/graph_print_pass.h"
|
||||
#include "paddle/fluid/framework/details/graph_test_base.h"
|
||||
|
||||
REGISTER_OPERATOR(sum, paddle::framework::DummyOp,
|
||||
paddle::framework::SumOpMaker);
|
||||
REGISTER_OPERATOR(split, paddle::framework::DummyOp,
|
||||
paddle::framework::SplitOpMaker);
|
||||
REGISTER_OPERATOR(assign, paddle::framework::DummyOp,
|
||||
paddle::framework::AssignOpMaker,
|
||||
paddle::framework::DummyVarTypeInference);
|
||||
|
||||
/*
|
||||
a @ b
|
||||
c
|
||||
d @ e
|
||||
*/
|
||||
|
||||
using paddle::framework::ProgramDesc;
|
||||
using paddle::framework::proto::VarType;
|
||||
|
||||
inline static ProgramDesc FillProgramDesc() {
|
||||
ProgramDesc prog;
|
||||
prog.MutableBlock(0)->Var("a")->SetType(VarType::LOD_TENSOR);
|
||||
prog.MutableBlock(0)->Var("b")->SetType(VarType::LOD_TENSOR);
|
||||
prog.MutableBlock(0)->Var("c")->SetType(VarType::LOD_TENSOR);
|
||||
prog.MutableBlock(0)->Var("d")->SetType(VarType::LOD_TENSOR);
|
||||
prog.MutableBlock(0)->Var("e")->SetType(VarType::LOD_TENSOR);
|
||||
{
|
||||
auto* op = prog.MutableBlock(0)->AppendOp();
|
||||
op->SetType("sum");
|
||||
op->SetInput("X", {"a", "b"});
|
||||
op->SetOutput("Out", {"c"});
|
||||
}
|
||||
{
|
||||
auto* op = prog.MutableBlock(0)->AppendOp();
|
||||
op->SetType("split");
|
||||
op->SetInput("X", {"c"});
|
||||
op->SetOutput("Out", {"d", "e"});
|
||||
}
|
||||
{
|
||||
auto* op = prog.MutableBlock(0)->AppendOp();
|
||||
op->SetType("sum");
|
||||
op->SetInput("X", {"d", "e"});
|
||||
op->SetOutput("Out", {"d"});
|
||||
}
|
||||
{
|
||||
auto* op = prog.MutableBlock(0)->AppendOp();
|
||||
op->SetType("assign");
|
||||
op->SetInput("X", {"d"});
|
||||
op->SetOutput("Out", {"d"});
|
||||
}
|
||||
return prog;
|
||||
}
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
TEST(SSAGraphPrinter, Normal) {
|
||||
auto program = FillProgramDesc();
|
||||
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
|
||||
graph->Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
|
||||
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
|
||||
|
||||
// redirect debug graph to a file.
|
||||
constexpr char graph_path[] = "graph_print_pass.txt";
|
||||
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
|
||||
PADDLE_ENFORCE(fout->good());
|
||||
printer->Print(*graph, *fout);
|
||||
}
|
||||
|
||||
using ir::Graph;
|
||||
using ir::Node;
|
||||
void BuildCircleGraph(Graph* g) {
|
||||
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
|
||||
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
|
||||
|
||||
o1->outputs.push_back(v1);
|
||||
o1->inputs.push_back(v1);
|
||||
v1->inputs.push_back(o1);
|
||||
v1->outputs.push_back(o1);
|
||||
}
|
||||
|
||||
void BuildCircleGraph2(Graph* g) {
|
||||
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
|
||||
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
|
||||
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
|
||||
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
|
||||
|
||||
o1->outputs.push_back(v1);
|
||||
o2->inputs.push_back(v1);
|
||||
v1->inputs.push_back(o1);
|
||||
v1->outputs.push_back(o2);
|
||||
|
||||
o2->outputs.push_back(v2);
|
||||
o1->inputs.push_back(v2);
|
||||
v2->inputs.push_back(o2);
|
||||
v2->outputs.push_back(o1);
|
||||
}
|
||||
|
||||
void BuildNoCircleGraph(Graph* g) {
|
||||
ir::Node* o1 = g->CreateEmptyNode("op1", Node::Type::kOperation);
|
||||
ir::Node* o2 = g->CreateEmptyNode("op2", Node::Type::kOperation);
|
||||
ir::Node* o3 = g->CreateEmptyNode("op3", Node::Type::kOperation);
|
||||
ir::Node* o4 = g->CreateEmptyNode("op4", Node::Type::kOperation);
|
||||
ir::Node* o5 = g->CreateEmptyNode("op5", Node::Type::kOperation);
|
||||
ir::Node* v1 = g->CreateEmptyNode("var1", Node::Type::kVariable);
|
||||
ir::Node* v2 = g->CreateEmptyNode("var2", Node::Type::kVariable);
|
||||
ir::Node* v3 = g->CreateEmptyNode("var3", Node::Type::kVariable);
|
||||
ir::Node* v4 = g->CreateEmptyNode("var4", Node::Type::kVariable);
|
||||
|
||||
// o1->v1->o2
|
||||
o1->outputs.push_back(v1);
|
||||
o2->inputs.push_back(v1);
|
||||
v1->inputs.push_back(o1);
|
||||
v1->outputs.push_back(o2);
|
||||
// o2->v2->o3
|
||||
// o2->v2->o4
|
||||
o2->outputs.push_back(v2);
|
||||
o3->inputs.push_back(v2);
|
||||
o4->inputs.push_back(v2);
|
||||
v2->inputs.push_back(o2);
|
||||
v2->outputs.push_back(o3);
|
||||
v2->outputs.push_back(o4);
|
||||
// o2->v3->o5
|
||||
o2->outputs.push_back(v3);
|
||||
o5->inputs.push_back(v3);
|
||||
v3->inputs.push_back(o2);
|
||||
v3->outputs.push_back(o5);
|
||||
// o3-v4->o5
|
||||
o3->outputs.push_back(v4);
|
||||
o5->inputs.push_back(v4);
|
||||
v4->inputs.push_back(o3);
|
||||
v4->outputs.push_back(o5);
|
||||
|
||||
// o2->v3->o1
|
||||
v3->outputs.push_back(o1);
|
||||
o1->inputs.push_back(v3);
|
||||
}
|
||||
|
||||
TEST(SSAGraphPrinter, SimpleCircle) {
|
||||
ProgramDesc prog;
|
||||
|
||||
Graph graph(prog);
|
||||
BuildCircleGraph(&graph);
|
||||
ASSERT_TRUE(HasCircle(graph));
|
||||
|
||||
graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
|
||||
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
|
||||
|
||||
// redirect debug graph to a file.
|
||||
constexpr char graph_path[] = "graph_print_pass_simple_circle.txt";
|
||||
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
|
||||
PADDLE_ENFORCE(fout->good());
|
||||
printer->Print(graph, *fout);
|
||||
}
|
||||
|
||||
TEST(SSAGraphPrinter, ComplexCircle) {
|
||||
ProgramDesc prog;
|
||||
Graph graph(prog);
|
||||
BuildCircleGraph2(&graph);
|
||||
ASSERT_TRUE(HasCircle(graph));
|
||||
|
||||
graph.Set<GraphvizNodes>(kGraphviz, new GraphvizNodes);
|
||||
std::unique_ptr<SSAGraphPrinter> printer(new SSAGraphPrinterImpl);
|
||||
|
||||
// redirect debug graph to a file.
|
||||
constexpr char graph_path[] = "graph_print_pass_complex_circle.txt";
|
||||
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_path));
|
||||
PADDLE_ENFORCE(fout->good());
|
||||
printer->Print(graph, *fout);
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue