You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/debug/draw.cc

656 lines
18 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "debug/draw.h"
#include <algorithm>
#include <iostream>
#include <iterator>
#include <map>
#include <vector>
#include <string>
#include "ir/meta_func_graph.h"
#include "ir/param_info.h"
#include "ir/primitive.h"
#include "ir/graph_utils.h"
#include "utils/utils.h"
#include "frontend/operator/composite/composite.h"
#include "ir/tensor.h"
namespace mindspore {
// namespace to support debug utils
namespace draw {
namespace {
// Only for ValueNode
std::string ValueType(const ValueNodePtr &node) {
if (node == nullptr) {
return "";
}
auto v = node->value();
return v->type_name();
}
std::string ReplaceSpecialChar(const std::string &str) {
std::ostringstream oss;
for (size_t i = 0; i < str.size(); i++) {
if (str[i] == '<') {
oss << "";
} else if (str[i] == '>') {
oss << "";
} else {
oss << str[i];
}
}
return oss.str();
}
} // namespace
// API of debug utils
void DrawNodes(const std::vector<AnfNodePtr> &nodes, OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs,
bool is_user) {
if (sub_graphs == nullptr) {
return;
}
for (auto &nd : nodes) {
MS_EXCEPTION_IF_NULL(nd);
auto sub_graph = nd->func_graph();
if (sub_graph != nullptr) {
auto gsub = (*sub_graphs)[sub_graph];
if (gsub == nullptr) {
if (is_user) {
gsub = std::make_shared<ModelDigraph>(sub_graph->ToString());
} else {
gsub = std::make_shared<Digraph>(sub_graph->ToString());
}
(*sub_graphs)[sub_graph] = gsub;
}
if (!nd->isa<Parameter>()) {
gsub->Node(nd);
}
}
}
}
void DrawValueNodes(const std::vector<AnfNodePtr> &nodes,
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> *sub_graphs) {
if (sub_graphs == nullptr) {
return;
}
int dup_idx = 0;
for (auto &nd : nodes) {
for (auto &t : SuccIncoming(nd)) {
MS_EXCEPTION_IF_NULL(t);
MS_EXCEPTION_IF_NULL(nd);
if (t->isa<ValueNode>() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) {
(*sub_graphs)[nd->func_graph()]->Node(t, dup_idx);
dup_idx++;
} else if (t->isa<Parameter>() && (*sub_graphs).find(t->func_graph()) != (*sub_graphs).end()) {
(*sub_graphs)[t->func_graph()]->Node(t, dup_idx);
dup_idx++;
}
}
}
}
void DrawEdges(const std::vector<AnfNodePtr> &nodes, const std::shared_ptr<BaseDigraph> &digraph, bool is_user) {
if (digraph == nullptr) {
return;
}
int dup_idx = 0;
int offset = 0;
if (is_user) {
offset = 1;
}
// Draw edge
for (auto &nd : nodes) {
auto succs = SuccIncoming(nd);
auto num = succs.size();
for (size_t i = 0; i < num; i++) {
auto &t = succs.at(i);
MS_EXCEPTION_IF_NULL(t);
if (t->isa<ValueNode>() || t->isa<Parameter>()) {
if ((!is_user) || (i != 0)) {
// `SizeToInt(i) - offset` is just for printing as text
digraph->Edge(t, nd, SizeToInt(i) - offset, dup_idx);
}
if (IsValueNode<FuncGraph>(t)) {
auto const_graph = GetValueNode<FuncGraphPtr>(t);
digraph->Edge(t, const_graph, dup_idx);
}
dup_idx++;
} else {
digraph->Edge(t, nd, SizeToInt(i) - offset);
}
}
}
}
void DrawByOpt(std::string filename, const FuncGraphPtr &func_graph, bool is_user) {
if (func_graph == nullptr) {
return;
}
auto ret = func_graph->get_return();
auto nodes = DeepScopedGraphSearch(ret);
std::shared_ptr<BaseDigraph> digraph;
OrderedMap<FuncGraphPtr, std::shared_ptr<BaseDigraph>> sub_graphs;
ChangeFileMode(filename, S_IRWXU);
if (is_user) {
digraph = std::make_shared<ModelDigraph>("mindspore", filename);
} else {
digraph = std::make_shared<Digraph>("mindspore", filename);
}
MS_EXCEPTION_IF_NULL(digraph);
digraph->Start();
// Draw nodes
DrawNodes(nodes, &sub_graphs, is_user);
// Draw ValueNodes on CNodes
DrawValueNodes(nodes, &sub_graphs);
// Draw subgraph
for (const auto &gsub : sub_graphs) {
digraph->SubGraph(gsub.first, gsub.second);
}
// Draw edge
DrawEdges(nodes, digraph, is_user);
digraph->End();
// set file mode to read only by user
ChangeFileMode(filename, S_IRUSR);
}
#ifdef ENABLE_DUMP_IR
void Draw(const std::string &filename, const FuncGraphPtr &func_graph) {
const std::string dot_suffix = ".dot";
std::string filename_with_suffix =
(filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename;
DrawByOpt(filename_with_suffix, func_graph, false);
}
void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) {
DrawByOpt(filename, func_graph, true);
}
#else
void Draw(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The functionality of dumping function graph IR in graphviz dot format is disabled, "
<< "please recompile source to enable it. See help of building script.";
}
void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) {
static bool already_printed = false;
if (already_printed) {
return;
}
already_printed = true;
MS_LOG(WARNING) << "The functionality of dumping function graph IR in graphviz dot format is disabled, "
<< "please recompile source to enable it. See help of building script.";
}
#endif
std::string Graphviz::Shape(AnfNodePtr node) {
if (node == nullptr) {
return "";
}
if (node->isa<CNode>()) {
return "plaintext";
}
if (node->isa<Parameter>()) {
return "octagon";
}
if (IsValueNode<FuncGraph>(node)) {
return "oval";
}
return "plaintext";
}
std::string Graphviz::Color(const AnfNodePtr &node) {
if (node == nullptr) {
return "";
}
if (node->isa<CNode>()) {
return "cornsilk";
}
if (node->isa<Parameter>()) {
return "paleturquoise";
}
if (IsValueNode<FuncGraph>(node)) {
return "palegreen";
}
return "lavender";
}
void BaseDigraph::Start() {
buffer_ << "digraph " << name_ << " {" << std::endl;
buffer_ << "compound=true" << std::endl;
}
void BaseDigraph::Head(const AnfNodePtr &node, int id) {
if (node == nullptr) {
return;
}
buffer_ << "node" << node << "_" << id;
if (node->isa<CNode>() || (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node))) {
buffer_ << ":core";
}
}
void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) {
if (node == nullptr) {
return;
}
buffer_ << "node" << node << "_" << id;
buffer_ << ":" << idx;
}
void BaseDigraph::Tail(const FuncGraphPtr &func_graph) {
if (func_graph == nullptr) {
return;
}
buffer_ << "node" << func_graph->get_return() << "_" << 0;
}
void BaseDigraph::Edge(AnfNodePtr start, FuncGraphPtr end, int id_start) {
Head(start, id_start);
buffer_ << "->";
Tail(end);
buffer_ << "[lhead=cluster_" << end;
buffer_ << ",dir=both,arrowhead=dot,style=filled,color=blue]";
buffer_ << std::endl;
}
void BaseDigraph::End() {
buffer_ << "}" << std::endl;
if (fout_.is_open()) {
fout_ << buffer_.str();
}
}
void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << "parameters_" << key << "[shape=plaintext ";
buffer_ << "label=<<table bgcolor='paleturquoise' cellspacing='0' cellborder='1' border='0'>";
buffer_ << "<tr><td>parameters</td></tr>";
int count = 0;
for (auto &parameter : key->parameters()) {
buffer_ << "<tr><td>";
buffer_ << parameter->ToString();
auto param = parameter->cast<ParameterPtr>();
if (param->has_default()) {
auto tensor_v = param->default_param();
if (tensor_v && tensor_v->isa<tensor::Tensor>()) {
auto tensor = tensor_v->cast<tensor::TensorPtr>();
auto &shape = tensor->shape();
std::ostringstream shape_str;
std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(shape_str, ","));
buffer_ << "[" << shape_str.str() << "]";
}
}
buffer_ << "</td></tr>";
count++;
// wrap the text.
if (count % 10 == 0) {
buffer_ << "\n";
}
}
buffer_ << "</table>>,];";
}
void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr<BaseDigraph> &gsub) {
if (key == nullptr || gsub == nullptr) {
return;
}
std::string label = key->debug_info()->get_full_name();
if (label.empty()) {
label = gsub->name();
}
std::string label_managed = "[managed]";
if (key->manager() == nullptr) {
label_managed = "[not managed]";
}
label += label_managed;
gsub->FuncGraphParameters(key);
buffer_ << "subgraph cluster_" << key << "{" << std::endl;
buffer_ << "id=cluster_" << key << std::endl;
buffer_ << "label=\"" << label << "\"" << std::endl;
buffer_ << "fontname=\"Courier New\"" << std::endl;
buffer_ << gsub->buffer().str();
buffer_ << "}" << std::endl;
}
Digraph::~Digraph() {
try {
if (fout_.is_open()) {
fout_.close();
}
} catch (const std::exception &e) {
MS_LOG(ERROR) << "Exception when closing file " << filename_;
}
}
static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) {
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
(void)str.replace(start_pos, from.length(), to);
start_pos += to.length(); // Handles case where 'to' is a substring of 'from'
}
return str;
}
static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) {
MS_EXCEPTION_IF_NULL(graph_obj);
graph_obj->buffer() << "label=<<table port='core' cellborder='0' cellspacing='2' bgcolor='" << graph_obj->Color(node)
<< "'>";
graph_obj->buffer() << "<tr><td bgcolor='white'>";
graph_obj->buffer() << ValueType(node);
graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "<tr><td>";
if (IsValueNode<MetaFuncGraph>(node)) {
graph_obj->buffer() << node->value()->cast<MetaFuncGraphPtr>()->name();
} else if (IsValueNode<parse::NameSpace>(node)) {
graph_obj->buffer() << node->value()->cast<parse::NameSpacePtr>()->name();
} else if (IsValueNode<parse::Symbol>(node)) {
graph_obj->buffer() << ReplaceSpecialChar(node->value()->cast<parse::SymbolPtr>()->name());
} else {
std::ostringstream ss;
ss << node->value()->ToString();
std::string s = ReplaceAll(ss.str(), ", ", "<br/>");
graph_obj->buffer() << s;
ValuePtr value = node->value();
if (value->isa<Primitive>()) {
PrimitivePtr primitive = value->cast<PrimitivePtr>();
graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "<tr><td align='left'>";
if (!primitive->instance_name().empty()) {
graph_obj->buffer() << "instance name:"
<< " ";
graph_obj->buffer() << primitive->instance_name();
graph_obj->buffer() << "<br/>";
}
auto attrs = primitive->attrs();
if (attrs.size() > 0) {
graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "<tr><td align='left'>";
int i = 0;
for (const auto &attr : attrs) {
if (i != 0) {
graph_obj->buffer() << "<br/>";
}
graph_obj->buffer() << attr.first << " ";
graph_obj->buffer() << attr.second->ToString();
i++;
}
}
}
}
graph_obj->buffer() << "</td></tr>";
graph_obj->buffer() << "</table>>,";
}
static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) {
return;
}
auto distributed_operation_info = node->user_data<parallel::OperatorInfo>();
if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) {
auto num = node->inputs().size();
graph_obj->buffer() << "<tr><td colspan='" << num << "' ";
graph_obj->buffer() << "bgcolor='" << graph_obj->Color(node) << "'>";
std::vector<ValuePtr> temp = {MakeValue(strategyPtr->GetInputStage()), MakeValue(strategyPtr->GetInputDim())};
ValueTuplePtr strategy_tuple = std::make_shared<ValueTuple>(temp);
graph_obj->buffer() << "Strategy " << strategy_tuple->ToString();
graph_obj->buffer() << "</td></tr>" << std::endl;
}
}
}
static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr || node->size() == 0) {
return;
}
auto num = node->size();
bool is_modelgraph = false;
if (typeid(*graph_obj) == typeid(ModelDigraph)) {
is_modelgraph = true;
num -= 1;
}
graph_obj->buffer() << "label=<<table port='core'>" << std::endl;
// Draw ports for CNode
if (num > 0) {
graph_obj->buffer() << "<tr>";
for (size_t i = 0; i < num; i++) {
graph_obj->buffer() << "<td port='" << i << "'>" << i << "</td>";
}
graph_obj->buffer() << "</tr>" << std::endl;
}
// Draw op name
graph_obj->buffer() << "<tr><td";
if (num > 0) {
graph_obj->buffer() << " colspan='" << num << "'";
}
graph_obj->buffer() << " bgcolor='" << graph_obj->Color(node) << "'>";
if (IsValueNode<Primitive>(node->input(0)) && is_modelgraph) {
auto primitive = GetValueNode<PrimitivePtr>(node->input(0));
graph_obj->buffer() << ReplaceAll(primitive->ToString(), ", ", "<br/>");
auto attrs = primitive->attrs();
if (attrs.size() > 0) {
graph_obj->buffer() << "</td></tr>" << std::endl << "<tr><td";
// Draw attrs
if (num > 0) {
graph_obj->buffer() << " colspan='" << num << "'";
}
graph_obj->buffer() << ">";
int i = 0;
for (auto &attr : attrs) {
if (i != 0) {
graph_obj->buffer() << "<br/>";
}
graph_obj->buffer() << attr.first << " " << attr.second->ToString();
i++;
}
}
graph_obj->buffer() << "CNode";
} else {
graph_obj->buffer() << "CNode(" << node->ToString() << ")";
}
graph_obj->buffer() << "</td></tr>" << std::endl;
DrawParallelInfo(graph_obj, node);
graph_obj->buffer() << "</table>>,";
}
void Digraph::Node(AnfNodePtr node, int id) {
if (node == nullptr) {
return;
}
buffer_ << "node" << node << "_" << id;
buffer_ << "[";
// set fontname
buffer_ << "fontname=\"Courier New\",";
// set label and shape
buffer_ << "shape=" << Shape(node) << ",";
if (node->isa<CNode>()) {
DrawCNode(this, node->cast<CNodePtr>());
} else if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
DrawValueNode(this, node->cast<ValueNodePtr>());
} else {
buffer_ << "label=\"" << node->ToString();
if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr nextNet = GetValueNode<FuncGraphPtr>(node);
std::string nextNetName = nextNet->debug_info()->get_full_name();
if (!nextNetName.empty()) {
buffer_ << "[" << nextNet->debug_info()->get_full_name().c_str() << "]";
}
}
buffer_ << "\","
<< "style=filled,fillcolor=" << Color(node) << ",";
}
// Set URL for func graph
if (IsValueNode<FuncGraph>(node)) {
buffer_ << "URL=\"#cluster_" << GetValueNode(node) << "\",";
}
buffer_ << "]" << std::endl;
}
void Digraph::Edge(AnfNodePtr start, AnfNodePtr end, int idx, int id_start) {
if (start == nullptr || end == nullptr) {
return;
}
Head(start, id_start);
buffer_ << "->";
Tail(end, idx);
buffer_ << "[arrowhead=vee,";
// check how many inputs for end
if (end->isa<CNode>()) {
auto cnode = end->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto num = cnode->inputs().size();
if (idx == 0 && num > 1) {
buffer_ << "style=dashed";
}
}
buffer_ << "]" << std::endl;
}
ModelDigraph::~ModelDigraph() {
try {
if (fout_.is_open()) {
fout_.close();
}
} catch (const std::exception &e) {
MS_LOG(ERROR) << "exception when closing file " << filename_;
}
}
std::string ModelDigraph::Shape(AnfNodePtr node) {
if (node == nullptr) {
return "";
}
if (node->isa<CNode>()) {
return "plaintext";
}
if (node->isa<Parameter>()) {
return "ellipse";
}
if (IsValueNode<FuncGraph>(node)) {
return "oval";
}
return "plaintext";
}
void ModelDigraph::Node(AnfNodePtr node, int id) {
if (node == nullptr) {
return;
}
if (IsValueNode<Primitive>(node)) {
return;
}
buffer_ << "node" << node << "_" << id;
buffer_ << "[";
// set fontname
buffer_ << "fontname=\"Courier New\",";
// set label and shape
buffer_ << "shape=" << Shape(node) << ",";
if (node->isa<CNode>()) {
DrawCNode(this, node->cast<CNodePtr>());
} else if (node->isa<ValueNode>() && !IsValueNode<FuncGraph>(node)) {
DrawValueNode(this, node->cast<ValueNodePtr>());
} else {
buffer_ << "label=\"" << node->ToString() << "\",";
buffer_ << "style=filled,fillcolor=" << Color(node) << ",";
}
// Set URL for func graph
if (IsValueNode<FuncGraph>(node)) {
buffer_ << "URL=\"#cluster_" << GetValueNode(node) << "\",";
}
buffer_ << "]" << std::endl;
}
void ModelDigraph::Edge(AnfNodePtr start, AnfNodePtr end, int idx, int id_start) {
if (start == nullptr || end == nullptr) {
return;
}
Head(start, id_start);
buffer_ << "->";
Tail(end, idx);
buffer_ << "[arrowhead=vee,";
buffer_ << "]" << std::endl;
}
struct DrawerRegister {
DrawerRegister() {
FuncGraph::set_drawer(
[](const std::string &filename, const FuncGraphPtr &func_graph) { Draw(filename, func_graph); });
}
~DrawerRegister() = default;
} drawer_regsiter;
} // namespace draw
} // namespace mindspore