|
|
|
@ -19,6 +19,7 @@ limitations under the License. */
|
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/graph_helper.h"
|
|
|
|
|
#include "paddle/fluid/framework/ir/node.h"
|
|
|
|
|
#include "paddle/fluid/framework/program_desc.h"
|
|
|
|
|
#include "paddle/fluid/platform/variant.h"
|
|
|
|
@ -26,6 +27,8 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
namespace ir {
|
|
|
|
|
template <typename PassType>
|
|
|
|
|
struct PassRegistrar;
|
|
|
|
|
|
|
|
|
|
class Pass {
|
|
|
|
|
public:
|
|
|
|
@ -40,7 +43,7 @@ class Pass {
|
|
|
|
|
attr_dels_.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const = 0;
|
|
|
|
|
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
|
|
|
|
|
|
|
|
|
|
// Get a reference to the attributed previously set.
|
|
|
|
|
template <typename AttrType>
|
|
|
|
@ -69,7 +72,25 @@ class Pass {
|
|
|
|
|
attrs_[attr_name] = attr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
virtual std::unique_ptr<Graph> ApplyImpl(
|
|
|
|
|
std::unique_ptr<Graph> graph) const = 0;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
template <typename PassType>
|
|
|
|
|
friend struct PassRegistrar;
|
|
|
|
|
|
|
|
|
|
void RegisterRequiredPassAttrs(const std::unordered_set<std::string> &attrs) {
|
|
|
|
|
required_pass_attrs_.insert(attrs.begin(), attrs.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RegisterRequiredGraphAttrs(
|
|
|
|
|
const std::unordered_set<std::string> &attrs) {
|
|
|
|
|
required_graph_attrs_.insert(attrs.begin(), attrs.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::string> required_pass_attrs_;
|
|
|
|
|
std::unordered_set<std::string> required_graph_attrs_;
|
|
|
|
|
std::map<std::string, boost::any> attrs_;
|
|
|
|
|
std::map<std::string, std::function<void(void)>> attr_dels_;
|
|
|
|
|
};
|
|
|
|
@ -119,10 +140,28 @@ struct PassRegistrar : public Registrar {
|
|
|
|
|
explicit PassRegistrar(const char *pass_type) {
|
|
|
|
|
PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
|
|
|
|
|
"'%s' is registered more than once.", pass_type);
|
|
|
|
|
PassRegistry::Instance().Insert(pass_type, []() -> std::unique_ptr<Pass> {
|
|
|
|
|
return std::unique_ptr<Pass>(new PassType());
|
|
|
|
|
});
|
|
|
|
|
PassRegistry::Instance().Insert(
|
|
|
|
|
pass_type, [this]() -> std::unique_ptr<Pass> {
|
|
|
|
|
std::unique_ptr<Pass> pass(new PassType());
|
|
|
|
|
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
|
|
|
|
|
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
|
|
|
|
|
return pass;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
|
|
|
|
|
required_pass_attrs_.insert(attr);
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
|
|
|
|
|
required_graph_attrs_.insert(attr);
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
std::unordered_set<std::string> required_pass_attrs_;
|
|
|
|
|
std::unordered_set<std::string> required_graph_attrs_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
|
|
|
|
@ -132,16 +171,19 @@ struct PassRegistrar : public Registrar {
|
|
|
|
|
msg)
|
|
|
|
|
|
|
|
|
|
// Register a new pass that can be applied on the IR.
|
|
|
|
|
#define REGISTER_PASS(pass_type, pass_class) \
|
|
|
|
|
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_pass__##pass_type, \
|
|
|
|
|
"REGISTER_PASS must be called in global namespace"); \
|
|
|
|
|
static ::paddle::framework::ir::PassRegistrar<pass_class> \
|
|
|
|
|
__pass_registrar_##pass_type##__(#pass_type); \
|
|
|
|
|
int TouchPassRegistrar_##pass_type() { \
|
|
|
|
|
__pass_registrar_##pass_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
#define REGISTER_PASS(pass_type, pass_class) \
|
|
|
|
|
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_pass__##pass_type, \
|
|
|
|
|
"REGISTER_PASS must be called in global namespace"); \
|
|
|
|
|
static ::paddle::framework::ir::PassRegistrar<pass_class> \
|
|
|
|
|
__pass_registrar_##pass_type##__(#pass_type); \
|
|
|
|
|
int TouchPassRegistrar_##pass_type() { \
|
|
|
|
|
__pass_registrar_##pass_type##__.Touch(); \
|
|
|
|
|
return 0; \
|
|
|
|
|
} \
|
|
|
|
|
static ::paddle::framework::ir::PassRegistrar<pass_class> \
|
|
|
|
|
&__pass_tmp_registrar_##pass_type##__ __attribute__((unused)) = \
|
|
|
|
|
__pass_registrar_##pass_type##__
|
|
|
|
|
|
|
|
|
|
#define USE_PASS(pass_type) \
|
|
|
|
|
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
|
|
|
|
|