feature/analysis node representation (#10522)
parent
8231960fb1
commit
de81ccb5cb
@ -1 +1,2 @@
|
||||
cc_library(dot SRCS dot.cc)
|
||||
cc_library(analysis SRCS dot.cc node.cc node.h)
|
||||
cc_test(test_node SRCS node_tester.cc DEPS analysis)
|
||||
|
@ -0,0 +1,23 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
enum class Device { CPU, GPU };
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,62 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/analysis/dot.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
class DotTester : public ::testing::Test {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::vector<Dot::Attr> attrs({{"title", "hello"}});
|
||||
dot.reset(new Dot(attrs));
|
||||
dot->AddNode("a", {Dot::Attr{"shape", "box"}, Dot::Attr("color", "blue")});
|
||||
dot->AddNode("b", {});
|
||||
dot->AddNode("c", {});
|
||||
dot->AddEdge("a", "b", {});
|
||||
dot->AddEdge("b", "c", {});
|
||||
dot->AddEdge("a", "c", {});
|
||||
}
|
||||
|
||||
std::unique_ptr<Dot> dot;
|
||||
};
|
||||
|
||||
TEST_F(DotTester, Build) {
|
||||
auto codes = dot->Build();
|
||||
// Output the DOT language code, the generated codes are too long to compare
|
||||
// the string.
|
||||
//
|
||||
// The output is
|
||||
//
|
||||
// digraph G {
|
||||
// title="hello"
|
||||
// node_1
|
||||
// node_2
|
||||
// node_0[label="a" shape="box" color="blue"]
|
||||
// node_0->node_1
|
||||
// node_1->node_2
|
||||
// node_0->node_2
|
||||
// } // end G
|
||||
LOG(INFO) << '\n' << codes;
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,74 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
template <typename IteratorT>
|
||||
class iterator_range {
|
||||
IteratorT begin_, end_;
|
||||
|
||||
public:
|
||||
template <typename Container>
|
||||
explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
|
||||
|
||||
iterator_range(const IteratorT &begin, const IteratorT &end)
|
||||
: begin_(begin), end_(end) {}
|
||||
|
||||
const IteratorT &begin() const { return begin_; }
|
||||
const IteratorT &end() const { return end_; }
|
||||
};
|
||||
|
||||
/*
|
||||
* An registry helper class, with its records keeps the order they registers.
|
||||
*/
|
||||
template <typename T>
|
||||
class OrderedRegistry {
|
||||
public:
|
||||
T *Register(const std::string &name, T *x) {
|
||||
PADDLE_ENFORCE(!dic_.count(name));
|
||||
dic_[name] = data_.size();
|
||||
data_.emplace_back(std::unique_ptr<T>(x));
|
||||
return data_.back().get();
|
||||
}
|
||||
|
||||
T *Lookup(const std::string &name) {
|
||||
auto it = dic_.find(name);
|
||||
if (it == dic_.end()) return nullptr;
|
||||
return data_[it->second].get();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::unordered_map<std::string, int> dic_;
|
||||
std::vector<std::unique_ptr<T>> data_;
|
||||
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
|
||||
\
|
||||
type__(const type__ &) = delete; \
|
||||
\
|
||||
void operator=(const type__ &) = delete;
|
@ -0,0 +1,67 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/analysis/node.h"
|
||||
#include "glog/logging.h"
|
||||
#include "paddle/fluid/platform/enforce.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
std::vector<Dot::Attr> Value::dot_attrs() const {
|
||||
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
|
||||
Dot::Attr("shape", "box"),
|
||||
Dot::Attr("fillcolor", "red")});
|
||||
}
|
||||
|
||||
std::vector<Dot::Attr> Function::dot_attrs() const {
|
||||
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
|
||||
Dot::Attr("shape", "diamond"),
|
||||
Dot::Attr("fillcolor", "yellow")});
|
||||
}
|
||||
|
||||
Node *NodeMap::Create(Node::Type type) {
|
||||
switch (type) {
|
||||
case Node::Type::kFunction:
|
||||
nodes_.emplace_back(new Function);
|
||||
break;
|
||||
case Node::Type::kValue:
|
||||
nodes_.emplace_back(new Value);
|
||||
break;
|
||||
default:
|
||||
PADDLE_THROW("Not supported node type.");
|
||||
}
|
||||
nodes_.back()->id_ = size() - 1;
|
||||
return nodes_.back().get();
|
||||
}
|
||||
|
||||
Node *NodeMap::GetMutable(size_t id) {
|
||||
PADDLE_ENFORCE_GT(size(), id);
|
||||
return nodes_[id].get();
|
||||
}
|
||||
|
||||
const Node &NodeMap::Get(size_t id) const {
|
||||
PADDLE_ENFORCE_GT(size(), id);
|
||||
return *nodes_[id].get();
|
||||
}
|
||||
|
||||
void NodeMap::Delete(size_t id) {
|
||||
PADDLE_ENFORCE_LT(id, size());
|
||||
nodes_[id]->SetDeleted();
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,234 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
/*
|
||||
* This file defines the Node class and its subclasses. A Node is the basis
|
||||
* analysis element in a computation graph.
|
||||
* There are basically two kinds of nodes, the function node and value node.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/inference/analysis/device.h"
|
||||
#include "paddle/fluid/inference/analysis/dot.h"
|
||||
#include "paddle/fluid/inference/analysis/helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
class NodeMap;
|
||||
|
||||
/*
|
||||
* Node Representation.
|
||||
*
|
||||
* This is a very important class for analysis. It is the base class of all
|
||||
* nodes computed by a program that may be used as operands to other nodes.
|
||||
* Node is the super class of other important classes such as Function and
|
||||
* Value, some nodes can have a name.
|
||||
*/
|
||||
class Node {
|
||||
public:
|
||||
// Node type. NOTE the new node types should add here.
|
||||
enum class Type { kNone = -1, kFunction, kValue, kFunctionBlock };
|
||||
|
||||
Node() = default;
|
||||
|
||||
struct Attr;
|
||||
|
||||
// Cast to a subclass type, Function for example.
|
||||
template <typename Subclass>
|
||||
Subclass &As() {
|
||||
return *dynamic_cast<Subclass *>(this);
|
||||
}
|
||||
|
||||
// Formatted representation of this Node.
|
||||
virtual std::string repr() const {
|
||||
return name() + "(" + std::to_string(id()) + ")";
|
||||
}
|
||||
|
||||
// DOT node representation. One Node type can customize its own node
|
||||
// representation.
|
||||
virtual std::vector<Dot::Attr> dot_attrs() const {
|
||||
return std::vector<Dot::Attr>({Dot::Attr("style", "filled")});
|
||||
}
|
||||
|
||||
// Get an additional attribute and convert it to T data type. NOTE this will
|
||||
// silently create a new attribute if not exists.
|
||||
Attr &attr(const std::string &name) { return attrs_[name]; }
|
||||
|
||||
int id() const { return id_; }
|
||||
|
||||
bool deleted() const { return deleted_; }
|
||||
void SetDeleted() { deleted_ = true; }
|
||||
|
||||
void SetName(const std::string &name) { name_ = name; }
|
||||
const std::string &name() const { return name_; }
|
||||
|
||||
void SetType(Type type) { type_ = type; }
|
||||
Type type() const { return type_; }
|
||||
|
||||
void *extra_info() const { return extra_info_; }
|
||||
void SetExtraInfo(void *extra_info) { extra_info_ = extra_info; }
|
||||
|
||||
// Input links.
|
||||
std::vector<Node *> inlinks;
|
||||
// Output links.
|
||||
std::vector<Node *> outlinks;
|
||||
|
||||
// A helper class to maintain the status from Pass.
|
||||
// TODO(superjomn) add a checker here to ensure the T is primary.
|
||||
struct Attr {
|
||||
// NOTE T should be a primary type or a struct combined by several primary
|
||||
// types.
|
||||
// NOTE the STL containers should not use here.
|
||||
// Some usages
|
||||
// Attr attr;
|
||||
// T data;
|
||||
// attr.data.assign((char*)data, sizeof(data));
|
||||
|
||||
bool &Bool() { return As<bool>(); }
|
||||
float &Float() { return As<float>(); }
|
||||
int32_t &Int32() { return As<int32_t>(); }
|
||||
int64_t &Int64() { return As<int64_t>(); }
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
T &As() {
|
||||
// init storage in the first usage.
|
||||
if (data_.empty()) {
|
||||
VLOG(4) << "resize data to " << sizeof(T);
|
||||
type_hash_ = typeid(T).hash_code();
|
||||
data_.resize(sizeof(T));
|
||||
}
|
||||
PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(), "type not matched");
|
||||
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
|
||||
return *reinterpret_cast<T *>(&data_[0]);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string data_;
|
||||
size_t type_hash_{std::numeric_limits<size_t>::max()};
|
||||
};
|
||||
|
||||
virtual ~Node() {}
|
||||
|
||||
friend class NodeMap;
|
||||
|
||||
PADDLE_DISALLOW_COPY_AND_ASSIGN(Node);
|
||||
|
||||
protected:
|
||||
// The id number not the name is a node's unique identifier in the computation
|
||||
// graph.
|
||||
int id_{-1};
|
||||
std::string name_;
|
||||
Type type_{Type::kNone};
|
||||
// Mark this node is deleted by some pass.
|
||||
bool deleted_{false};
|
||||
|
||||
void *extra_info_;
|
||||
|
||||
mutable std::unordered_map<std::string, Attr> attrs_;
|
||||
};
|
||||
|
||||
class Function;
|
||||
/*
|
||||
* Value represents a value node, it has some attributes including dims, data
|
||||
* type and so on.
|
||||
*/
|
||||
class Value : public Node {
|
||||
public:
|
||||
enum class DataType { kInt32, kInt64, kFloat32, kFloat64 };
|
||||
using Dims = std::vector<int>;
|
||||
|
||||
void SetDataType(DataType data_type) { data_type_ = data_type; }
|
||||
DataType data_type() const { return data_type_; }
|
||||
|
||||
void SetDims(const Dims &dims) { dims_ = dims; }
|
||||
const Dims &dims() const { return dims_; }
|
||||
|
||||
Device device() const { return device_; }
|
||||
void SetDevice(Device device) { device_ = device; }
|
||||
|
||||
std::vector<Dot::Attr> dot_attrs() const override;
|
||||
|
||||
PADDLE_DISALLOW_COPY_AND_ASSIGN(Value);
|
||||
|
||||
protected:
|
||||
Value() { SetType(Node::Type::kValue); }
|
||||
friend class NodeMap;
|
||||
|
||||
private:
|
||||
DataType data_type_;
|
||||
Dims dims_;
|
||||
Device device_;
|
||||
};
|
||||
|
||||
/*
|
||||
* Function represents any kind of executable concepts that takes several Values
|
||||
* as input, and outputs several Values.
|
||||
*/
|
||||
class Function : public Node {
|
||||
public:
|
||||
std::vector<Dot::Attr> dot_attrs() const override;
|
||||
|
||||
// Get the operator's type from Desc.
|
||||
const std::string &func_type() const { return func_type_; }
|
||||
// Set the operator's type.
|
||||
void SetFuncType(const std::string &func_type) { func_type_ = func_type; }
|
||||
|
||||
PADDLE_DISALLOW_COPY_AND_ASSIGN(Function);
|
||||
|
||||
protected:
|
||||
std::string func_type_;
|
||||
Function() { SetType(Node::Type::kFunction); }
|
||||
friend class NodeMap;
|
||||
};
|
||||
|
||||
/*
|
||||
* FunctionBlock is a Node that contains a sub-graph multiple Node.
|
||||
*/
|
||||
struct FunctionBlock : public Node {
|
||||
std::string repr() const override { return "block-" + std::to_string(id()); }
|
||||
std::vector<Node *> subgraph;
|
||||
};
|
||||
|
||||
class NodeMap {
|
||||
public:
|
||||
// Create a new node with type.
|
||||
Node *Create(Node::Type type);
|
||||
|
||||
// Get a node by its id.
|
||||
Node *GetMutable(size_t id);
|
||||
|
||||
const Node &Get(size_t id) const;
|
||||
|
||||
void Delete(size_t id);
|
||||
|
||||
const std::vector<std::unique_ptr<Node>> &nodes() { return nodes_; }
|
||||
|
||||
size_t size() const { return nodes_.size(); }
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<Node>> nodes_;
|
||||
std::unordered_map<std::string, Node *> map_;
|
||||
};
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
@ -0,0 +1,34 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/inference/analysis/node.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace analysis {
|
||||
|
||||
TEST(Node, Attr) {
|
||||
// Node is an abstract class, use Value instead for they share the same Attr
|
||||
// logic.
|
||||
NodeMap nodes;
|
||||
auto* node = nodes.Create(Node::Type::kValue);
|
||||
node->attr("v0").Int32() = 2008;
|
||||
ASSERT_EQ(node->attr("v0").Int32(), 2008);
|
||||
}
|
||||
|
||||
} // namespace analysis
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Loading…
Reference in new issue