clean up and correctness check

bugfix/anakin-compile
Xin Pan 7 years ago
parent aa1085ddc5
commit ab72d28a5e

@ -75,7 +75,12 @@ can also fuse some `Graph`'s `Node`s.
class Pass {
public:
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const {
// Some correctness check.
auto new_graph = ApplyImpl(std::move(graph));
// Some correctness check.
return new_graph;
}
// Get a reference to the attributed previously set.
template <typename AttrType>
@ -89,6 +94,9 @@ class Pass {
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr);
protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const = 0;
};
// In my_pass.cc

@ -31,8 +31,8 @@ class Scope;
namespace details {
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
private:

@ -18,7 +18,7 @@ namespace paddle {
namespace framework {
namespace details {
void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
continue;
@ -50,7 +50,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
auto *dep_var = new DummyVarHandle(graph->CreateControlDepVar());
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
graph->Get<GraphDepVars>("dep_vars").emplace(dep_var);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dep_var);
}
}
}
@ -60,7 +60,7 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) {
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
ir::Graph *graph, ir::Node *node, const platform::Place &place,
size_t place_offset) {
auto &var_holders = graph->Get<GraphVars>("vars")[place_offset];
auto &var_holders = graph->Get<GraphVars>(kGraphVars)[place_offset];
auto &var_holder = var_holders[node->Name()];
VarHandle *var = nullptr;
if (var_holder.empty()) {
@ -83,7 +83,8 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
ir::Node *new_node,
const platform::Place &place,
size_t place_offset) {
auto &vars = graph->Get<GraphVars>("vars")[place_offset][new_node->Name()];
auto &vars =
graph->Get<GraphVars>(kGraphVars)[place_offset][new_node->Name()];
size_t version = vars.size();
auto var =
new VarHandle(new_node, version, place_offset, new_node->Name(), place);
@ -92,12 +93,12 @@ void SSAGraphBuilder::CreateOpOutput(ir::Graph *graph, OpHandleBase *op_handle,
}
void SSAGraphBuilder::AddOutputToLeafOps(ir::Graph *graph) {
for (auto &op : graph->Get<GraphOps>("ops")) {
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (!op->Outputs().empty()) {
continue;
}
auto *dummy_leaf = new DummyVarHandle(graph->CreateControlDepVar());
graph->Get<GraphDepVars>("dep_vars").emplace(dummy_leaf);
graph->Get<GraphDepVars>(kGraphDepVars).emplace(dummy_leaf);
op->AddOutput(dummy_leaf);
}
}

@ -39,15 +39,19 @@ namespace details {
typedef std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
GraphVars;
const char kGraphVars[] = "vars";
// aux variables to represent dependency. Useful to resolve data hazard.
typedef std::unordered_set<std::unique_ptr<VarHandleBase>> GraphDepVars;
const char kGraphDepVars[] = "dep_vars";
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
class SSAGraphBuilder : public ir::Pass {
public:

@ -33,7 +33,7 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
};
for (auto &var_map : graph->Get<GraphVars>("vars")) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
insert_pending_var(version_pair.get());
@ -41,11 +41,11 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
}
}
for (auto &var : graph->Get<GraphDepVars>("dep_vars")) {
for (auto &var : graph->Get<GraphDepVars>(kGraphDepVars)) {
insert_pending_var(var.get());
}
for (auto &op : graph->Get<GraphOps>("ops")) {
for (auto &op : graph->Get<GraphOps>(kGraphOps)) {
if (op->Inputs().empty()) {
ready_ops.insert(op.get());
} else {
@ -87,4 +87,8 @@ bool SSAGraghBuilderWithChecker::IsValidGraph(const ir::Graph *graph) const {
} // namespace paddle
REGISTER_PASS(multi_device_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker);
paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps)
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);

@ -23,8 +23,8 @@ namespace framework {
namespace details {
class SSAGraghBuilderWithChecker : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
PADDLE_ENFORCE(IsValidGraph(graph.get()));
return graph;

@ -22,7 +22,7 @@ namespace details {
template <typename Callback>
static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
for (auto &each : graph.Get<GraphVars>("vars")) {
for (auto &each : graph.Get<GraphVars>(kGraphVars)) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
@ -30,7 +30,7 @@ static inline void IterAllVar(const ir::Graph &graph, Callback callback) {
}
}
for (auto &var : graph.Get<GraphDepVars>("dep_vars")) {
for (auto &var : graph.Get<GraphDepVars>(kGraphDepVars)) {
callback(*var);
}
}
@ -61,7 +61,7 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
});
size_t op_id = 0;
for (auto &op : graph.Get<GraphOps>("ops")) {
for (auto &op : graph.Get<GraphOps>(kGraphOps)) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;

@ -36,8 +36,8 @@ class GraphvizSSAGraphPrinter : public SSAGraphPrinter {
};
class SSAGraghBuilderWithPrinter : public SSAGraphBuilder {
public:
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override {
std::unique_ptr<std::ostream> fout(
new std::ofstream(Get<const std::string>("debug_graphviz_path")));

@ -45,18 +45,18 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
std::unordered_set<OpHandleBase *> delayed_ops;
// Transform SSAGraph to pending_ops & pending_vars
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
InsertPendingVar(&pending_vars, &ready_vars, version_pair.get());
}
}
}
for (auto &var : graph_->Get<details::GraphDepVars>("dep_vars")) {
for (auto &var : graph_->Get<details::GraphDepVars>(details::kGraphDepVars)) {
InsertPendingVar(&pending_vars, &ready_vars, var.get());
}
for (auto &op : graph_->Get<details::GraphOps>("ops")) {
for (auto &op : graph_->Get<details::GraphOps>(details::kGraphOps)) {
if (op->Inputs().empty()) { // Special case, Op has no input.
ready_ops.insert(op.get());
} else {
@ -162,7 +162,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps(
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : graph_->Get<details::GraphVars>("vars")) {
for (auto &var_map : graph_->Get<details::GraphVars>(details::kGraphVars)) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(it->second.rbegin()->get());

@ -40,10 +40,14 @@ class Graph {
attr_dels_.clear();
}
bool Has(const std::string &attr_name) const {
return attrs_.find(attr_name) != attrs_.end();
}
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for graph.", attr_name);
PADDLE_ENFORCE(Has(attr_name), "%s attr not registered for graph.",
attr_name);
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
}

@ -20,10 +20,11 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
static const char kGraphVizPath[] = "graph_viz_path";
std::unique_ptr<ir::Graph> GraphVizPass::Apply(
std::unique_ptr<ir::Graph> GraphVizPass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
const std::string graph_viz_path = Get<std::string>("graph_viz_path");
const std::string graph_viz_path = Get<std::string>(kGraphVizPath);
std::unique_ptr<std::ostream> fout(new std::ofstream(graph_viz_path));
PADDLE_ENFORCE(fout->good());
std::ostream& sout = *fout;
@ -67,4 +68,5 @@ std::unique_ptr<ir::Graph> GraphVizPass::Apply(
} // namespace framework
} // namespace paddle
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass);
REGISTER_PASS(graph_viz_pass, paddle::framework::ir::GraphVizPass)
.RequirePassAttr(paddle::framework::ir::kGraphVizPath);

@ -28,8 +28,8 @@ namespace framework {
namespace ir {
class GraphVizPass : public Pass {
public:
std::unique_ptr<ir::Graph> Apply(
protected:
std::unique_ptr<ir::Graph> ApplyImpl(
std::unique_ptr<ir::Graph> graph) const override;
};

@ -17,6 +17,22 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
std::unique_ptr<Graph> Pass::Apply(std::unique_ptr<Graph> graph) const {
for (const std::string& attr : required_pass_attrs_) {
PADDLE_ENFORCE(attrs_.find(attr) != attrs_.end(),
"Required pass atrribute %s not registered.", attr);
}
for (const std::string& attr : required_graph_attrs_) {
PADDLE_ENFORCE(graph->Has(attr), "Required graph atrribute %s not exist.",
attr);
}
auto applied_graph = ApplyImpl(std::move(graph));
// TODO(panyx0718): Add more verifications.
PADDLE_ENFORCE(!HasCircle(*applied_graph),
"Illegal Pass. Generated graph shouldn't has cycle.");
return applied_graph;
}
PassRegistry& PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;

@ -19,6 +19,7 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
@ -26,6 +27,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
template <typename PassType>
struct PassRegistrar;
class Pass {
public:
@ -40,7 +43,7 @@ class Pass {
attr_dels_.clear();
}
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
// Get a reference to the attributed previously set.
template <typename AttrType>
@ -69,7 +72,25 @@ class Pass {
attrs_[attr_name] = attr;
}
protected:
virtual std::unique_ptr<Graph> ApplyImpl(
std::unique_ptr<Graph> graph) const = 0;
private:
template <typename PassType>
friend struct PassRegistrar;
void RegisterRequiredPassAttrs(const std::unordered_set<std::string> &attrs) {
required_pass_attrs_.insert(attrs.begin(), attrs.end());
}
void RegisterRequiredGraphAttrs(
const std::unordered_set<std::string> &attrs) {
required_graph_attrs_.insert(attrs.begin(), attrs.end());
}
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
@ -119,10 +140,28 @@ struct PassRegistrar : public Registrar {
explicit PassRegistrar(const char *pass_type) {
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
"'%s' is registered more than once.", pass_type);
PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr<Pass> {
return std::unique_ptr<Pass>(new PassType());
});
PassRegistry::Instance().Insert(
pass_type, [this]() -> std::unique_ptr<Pass> {
std::unique_ptr<Pass> pass(new PassType());
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
return pass;
});
}
PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
required_pass_attrs_.insert(attr);
return *this;
}
PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
required_graph_attrs_.insert(attr);
return *this;
}
private:
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
};
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
@ -132,16 +171,19 @@ struct PassRegistrar : public Registrar {
msg)
// Register a new pass that can be applied on the IR.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
}
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
__pass_registrar_##pass_type##__
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \

@ -213,7 +213,7 @@ void ParallelExecutor::BCastParamsToDevices(
if (member_->executor_) {
auto &sharded_var_device =
member_->executor_->Graph().Get<details::ShardedVarDevice>(
"sharded_var_device");
details::kShardedVarDevice);
if (sharded_var_device.find(var) != sharded_var_device.end()) {
var_dev_id = sharded_var_device.at(var);
}

Loading…
Cancel
Save