clean up and correctness check

bugfix/anakin-compile
Xin Pan 7 years ago
parent aa1085ddc5
commit ab72d28a5e

@ -75,7 +75,12 @@ can also fuse some `Graph`'s `Node`s.
class Pass { class Pass {
public: public:
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0; std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
// Some correctness check.
auto new_graph = ApplyImpl(std::move(graph));
// Some correctness check.
return new_graph;
}
// Get a reference to the attributed previously set. // Get a reference to the attributed previously set.
template <typename AttrType> template <typename AttrType>
@ -89,6 +94,9 @@ class Pass {
// should delete the attribute. // should delete the attribute.
template <typename AttrType> template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr); void SetNotOwned(const std::string &attr_name, AttrType *attr);
protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
}; };
// In my_pass.cc // In my_pass.cc

@ -31,8 +31,8 @@ class Scope;
namespace details { namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder { class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public: protected:
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
private: private:

@ -18,7 +18,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) { if (name_pair.second.size() <= 1) {
continue; continue;
@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar()); auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var); read_op->AddOutput(dep_var);
write_op->AddInput(dep_var); write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var); graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
} }
} }
} }
@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place, ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset]; auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[node->Name()]; auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node, ir::Node *new_node,
const platform::Place &place, const platform::Place &place,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()]; auto &vars =
graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
size_t version = vars.size(); size_t version = vars.size();
auto var = auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place); new VarHandle(new_node, version, place_offset, new_node->Name(), place);
@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
} }
void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) { void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) { for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (!op->Outputs().empty()) { if (!op->Outputs().empty()) {
continue; continue;
} }
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar()); auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf); graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
op->AddOutput(dummy_leaf); op->AddOutput(dummy_leaf);
} }
} }

@ -39,15 +39,19 @@ namespace details {
typedef std::vector< typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>> std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars; GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars; typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is // all operators. NOTE that even we use a vector here, the operators is
// unordered. // unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps; typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice; typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
class SSAGraphBuilder : public ir::Pass { class SSAGraphBuilder : public ir::Pass {
public: public:

@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} }
}; };
for (auto &var_map : graph->Get<GraphVars>("vars")) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get()); insert_pending_var(version_pair.get());
@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} }
} }
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) { for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get()); insert_pending_var(var.get());
} }
for (auto &op : graph->Get<GraphOps>("ops")) { for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (op->Inputs().empty()) { if (op->Inputs().empty()) {
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
@ -87,4 +87,8 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} // namespace paddle } // namespace paddle
REGISTER_PASS(multi_device_check_pass, REGISTER_PASS(multi_device_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker); paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);

@ -23,8 +23,8 @@ namespace framework {
namespace details { namespace details {
class SSAGraghBuilderWithChecker : public SSAGraphBuilder { class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public: protected:
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get())); PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph; return graph;

@ -22,7 +22,7 @@ namespace details {
template <typename Callback> template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) { static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) { for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &pair1 : each) { for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) { for (auto &pair2 : pair1.second) {
callback(*pair2); callback(*pair2);
@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
} }
} }
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) { for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
callback(*var); callback(*var);
} }
} }
@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
}); });
size_t op_id = 0; size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>("ops")) { for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
std::string op_name = "op_" + std::to_string(op_id++); std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]" sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl; << std::endl;

@ -36,8 +36,8 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
}; };
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder { class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public: protected:
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override { std::unique_ptr<ir::Graph> graph) const override {
std::unique_ptr<std::ostream> fout( std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<const std::string>("debug_graphviz_path"))); new std::ofstream(Get<const std::string>("debug_graphviz_path")));

@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops; std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars // Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) { for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) { for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get()); InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
} }
} }
} }
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) { for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, &ready_vars, var.get()); InsertPendingVar(&pending_vars, &ready_vars, var.get());
} }
for (auto &op : graph_->Get<details::GraphOps>("ops")) { for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
if (op->Inputs().empty()) { // Special case, Op has no input. if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get()); ready_ops.insert(op.get());
} else { } else {
@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars; std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) { for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) { for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name); auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) { if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get()); fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());

@ -40,10 +40,14 @@ class Graph {
attr_dels_.clear(); attr_dels_.clear();
} }
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
}
template <typename AttrType> template <typename AttrType>
AttrType &Get(const std::string &attr_name) const { AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(), PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
"%s attr not registered for graph.", attr_name); attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name)); return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} }

@ -20,10 +20,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static const char kGraphVizPath[] = "graph_viz_path";
std::unique_ptr<ir::Graph> GraphVizPass::Apply( std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
const std::string graph_viz_path = Get<std::string>("graph_viz_path"); const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path)); std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good()); PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout; std::ostream& sout = *fout;
@ -67,4 +68,5 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass); REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);

@ -28,8 +28,8 @@ namespace framework {
namespace ir { namespace ir {
class GraphVizPass : public Pass { class GraphVizPass : public Pass {
public: protected:
std::unique_ptr<ir::Graph> Apply( std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override; std::unique_ptr<ir::Graph> graph) const override;
}; };

@ -17,6 +17,22 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
"Required pass atrribute %s not registered.", attr);
}
for (const std::string& attr : required_graph_attrs_) {
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not exist.",
attr);
}
auto applied_graph = ApplyImpl(std::move(graph));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*applied_graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
return applied_graph;
}
PassRegistry& PassRegistry::Instance() { PassRegistry& PassRegistry::Instance() {
static PassRegistry g_pass_info_map; static PassRegistry g_pass_info_map;
return g_pass_info_map; return g_pass_info_map;

@ -19,6 +19,7 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h" #include "paddle/fluid/platform/variant.h"
@ -26,6 +27,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
template <typename PassType>
struct PassRegistrar;
class Pass { class Pass {
public: public:
@ -40,7 +43,7 @@ class Pass {
attr_dels_.clear(); attr_dels_.clear();
} }
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0; std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
// Get a reference to the attributed previously set. // Get a reference to the attributed previously set.
template <typename AttrType> template <typename AttrType>
@ -69,7 +72,25 @@ class Pass {
attrs_[attr_name] = attr; attrs_[attr_name] = attr;
} }
protected:
virtual std::unique_ptr<Graph> ApplyImpl(
std::unique_ptr<Graph> graph) const = 0;
private: private:
template <typename PassType>
friend struct PassRegistrar;
void RegisterRequiredPassAttrs(const std::unordered_set<std::string> &attrs) {
required_pass_attrs_.insert(attrs.begin(), attrs.end());
}
void RegisterRequiredGraphAttrs(
const std::unordered_set<std::string> &attrs) {
required_graph_attrs_.insert(attrs.begin(), attrs.end());
}
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> attrs_; std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_; std::map<std::string, std::function<void(void)>> attr_dels_;
}; };
@ -119,10 +140,28 @@ struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) { explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type), PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type); "'%s' is registered more than once.", pass_type);
PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr<Pass> { PassRegistry::Instance().Insert(
return std::unique_ptr<Pass>(new PassType()); pass_type, [this]() -> std::unique_ptr<Pass> {
}); std::unique_ptr<Pass> pass(new PassType());
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
return pass;
});
} }
PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
required_pass_attrs_.insert(attr);
return *this;
}
PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
required_graph_attrs_.insert(attr);
return *this;
}
private:
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
}; };
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \ #define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
@ -132,16 +171,19 @@ struct PassRegistrar : public Registrar {
msg) msg)
// Register a new pass that can be applied on the IR. // Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \ #define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \ __reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \ "REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \ static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \ __pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \ int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \ __pass_registrar_##pass_type##__.Touch(); \
return 0; \ return 0; \
} } \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \ #define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \ STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \

@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
if (member_->executor_) { if (member_->executor_) {
auto &sharded_var_device = auto &sharded_var_device =
member_->executor_->Graph().Get<details::ShardedVarDevice>( member_->executor_->Graph().Get<details::ShardedVarDevice>(
"sharded_var_device"); details::kShardedVarDevice);
if (sharded_var_device.find(var) != sharded_var_device.end()) { if (sharded_var_device.find(var) != sharded_var_device.end()) {
var_dev_id = sharded_var_device.at(var); var_dev_id = sharded_var_device.at(var);
} }

Loading…
Cancel
Save