You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/paddle/fluid/framework/ir/pass.h

232 lines
7.8 KiB

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle {
namespace framework {
namespace ir {
// Forward declaration of the registrar template. It is needed before Pass is
// defined because Pass declares every PassRegistrar<PassType> a friend; the
// full definition appears later in this header.
template <typename PassType>
struct PassRegistrar;
class Pass {
public:
Pass() = default;
virtual ~Pass() {
for (auto &attr : attrs_) {
if (attr_dels_.find(attr.first) != attr_dels_.end()) {
attr_dels_[attr.first]();
}
}
attrs_.clear();
attr_dels_.clear();
}
std::string Type() const { return type_; }
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph) const;
// Get a reference to the attributed previously set.
template <typename AttrType>
AttrType &Get(const std::string &attr_name) const {
PADDLE_ENFORCE(attrs_.find(attr_name) != attrs_.end(),
"%s attr not registered for pass.", attr_name);
try {
return *boost::any_cast<AttrType *>(attrs_.at(attr_name));
} catch (boost::bad_any_cast &) {
PADDLE_THROW(
"Invalid attribute type of %s error, expected: %s, actual: %s",
attr_name, typeid(AttrType *).name(),
attrs_.at(attr_name).type().name());
}
}
bool Has(const std::string &attr_name) const {
return attrs_.count(attr_name) > 0;
}
void Erase(const std::string &attr_name) {
if (!Has(attr_name)) {
return;
}
if (attr_dels_.find(attr_name) != attr_dels_.end()) {
attr_dels_[attr_name]();
attr_dels_.erase(attr_name);
}
attrs_.erase(attr_name);
}
// Set a pointer to the attribute. Pass takes ownership of the attribute.
template <typename AttrType>
void Set(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
attr_name);
attrs_[attr_name] = attr;
attr_dels_[attr_name] = [attr, attr_name]() {
VLOG(3) << "deleting " << attr_name;
delete attr;
};
}
// Set a pointer to the attribute. Pass doesn't take ownership. Caller
// should delete the attribute.
template <typename AttrType>
void SetNotOwned(const std::string &attr_name, AttrType *attr) {
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
attr_name);
attrs_[attr_name] = attr;
}
protected:
virtual std::unique_ptr<Graph> ApplyImpl(std::unique_ptr<Graph> graph) const {
LOG(FATAL) << "Calling virtual Pass not implemented.";
return graph;
}
private:
template <typename PassType>
friend struct PassRegistrar;
void RegisterRequiredPassAttrs(const std::unordered_set<std::string> &attrs) {
required_pass_attrs_.insert(attrs.begin(), attrs.end());
}
void RegisterRequiredGraphAttrs(
const std::unordered_set<std::string> &attrs) {
required_graph_attrs_.insert(attrs.begin(), attrs.end());
}
void RegisterType(const std::string &type) { type_ = type; }
mutable bool applied_{false};
std::string type_;
std::unordered_set<std::string> required_pass_attrs_;
std::unordered_set<std::string> required_graph_attrs_;
std::map<std::string, boost::any> attrs_;
std::map<std::string, std::function<void(void)>> attr_dels_;
};
// Factory callable that takes no arguments and returns a newly created Pass,
// handing ownership to the caller through std::unique_ptr.
using PassCreator = std::function<std::unique_ptr<Pass>()>;
// Common base of all registrar types.
class Registrar {
 public:
  // Each kind of pass has its own registry and registrar, and the actual
  // registration runs in the constructor of a global registrar variable.
  // Code that merely uses the framework package never refers to those
  // variables, so the linker is free to drop them from the final binary.
  // Touch() is an intentionally empty method that gives the USE_PASS macro
  // something to call: as long as the calling code invokes USE_PASS, the
  // global registrar variable is referenced and is not removed by the linker.
  void Touch() {}
};
class PassRegistry {
public:
static PassRegistry &Instance();
bool Has(const std::string &pass_type) const {
return map_.find(pass_type) != map_.end();
}
void Insert(const std::string &pass_type, const PassCreator &pass_creator) {
PADDLE_ENFORCE(!Has(pass_type), "Pass %s has been registered", pass_type);
map_.insert({pass_type, pass_creator});
}
std::unique_ptr<Pass> Get(const std::string &pass_type) const {
PADDLE_ENFORCE(Has(pass_type), "Pass %s has not been registered",
pass_type);
return map_.at(pass_type)();
}
private:
PassRegistry() = default;
std::unordered_map<std::string, PassCreator> map_;
DISABLE_COPY_AND_ASSIGN(PassRegistry);
};
// Registers PassType in the global PassRegistry under a given type name.
//
// The constructor installs a factory that creates a PassType, copies the
// registrar's required pass/graph attribute names into it and sets its type
// name. RequirePassAttr()/RequireGraphAttr() return *this so that calls can
// be chained.
//
// Lifetime: the installed factory reads the attribute sets through `this`
// every time a pass is created, which is what lets names added after
// construction take effect. The registrar therefore has to stay alive, at the
// same address, for as long as the registry may create passes of this type.
template <typename PassType>
struct PassRegistrar : public Registrar {
  // pass_type: name to register PassType under. Enforces that the name is not
  // registered yet.
  explicit PassRegistrar(const char *pass_type) {
    PADDLE_ENFORCE(!PassRegistry::Instance().Has(pass_type),
                   "'%s' is registered more than once.", pass_type);
    // The factory outlives this constructor call, so it must not hold on to
    // the caller's character buffer. Capture an owned copy of the name
    // instead; this also avoids rebuilding a std::string on every creation.
    std::string type_name(pass_type);
    PassRegistry::Instance().Insert(
        type_name, [this, type_name]() -> std::unique_ptr<Pass> {
          std::unique_ptr<Pass> pass(new PassType());
          pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
          pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
          pass->RegisterType(type_name);
          return pass;
        });
  }

  // Adds `attr` to the pass attribute names handed to every created pass.
  PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
    required_pass_attrs_.insert(attr);
    return *this;
  }

  // Adds `attr` to the graph attribute names handed to every created pass.
  PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
    required_graph_attrs_.insert(attr);
    return *this;
  }

 private:
  std::unordered_set<std::string> required_pass_attrs_;
  std::unordered_set<std::string> required_graph_attrs_;
};
// Compile-time guard that the macro using it is expanded at global namespace
// scope. It declares an empty struct whose name is built from `uniq_name`,
// then static_asserts (with message `msg`) that the globally qualified name
// `::__test_global_namespace_<uniq_name>__` denotes the same type as the
// unqualified name. That can only hold when the struct was declared in the
// global namespace. The expansion ends without a semicolon.
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)
// Register a new pass that can be applied on the IR.
//
// pass_type:  identifier used (stringified) as the registered type name and
//             pasted into the generated symbol names.
// pass_class: the Pass subclass to instantiate.
//
// The expansion
//   1. checks that it is placed in the global namespace,
//   2. defines a static PassRegistrar<pass_class> constructed with
//      "#pass_type",
//   3. defines the non-static function TouchPassRegistrar_<pass_type>(),
//      which calls Touch() on that registrar and returns 0 (USE_PASS refers
//      to this function), and
//   4. declares a static reference to the registrar, marked UNUSED (macro
//      defined elsewhere).
// The expansion ends in an expression without a semicolon, so the caller
// supplies the terminating `;` -- presumably so that calls such as
// `.RequirePassAttr(...)` can be appended to it.
#define REGISTER_PASS(pass_type, pass_class) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__reg_pass__##pass_type, \
"REGISTER_PASS must be called in global namespace"); \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
__pass_registrar_##pass_type##__(#pass_type); \
int TouchPassRegistrar_##pass_type() { \
__pass_registrar_##pass_type##__.Touch(); \
return 0; \
} \
static ::paddle::framework::ir::PassRegistrar<pass_class> \
&__pass_tmp_registrar_##pass_type##__ UNUSED = \
__pass_registrar_##pass_type##__
// References the pass registered as `pass_type` from the current translation
// unit. After checking that it is placed in the global namespace, the
// expansion declares TouchPassRegistrar_<pass_type>() (the function defined
// by REGISTER_PASS) as extern and initializes a static int with its result.
// That call is what references the registrar so the linker does not discard
// it. The expansion ends without a semicolon; the caller supplies it.
#define USE_PASS(pass_type) \
STATIC_ASSERT_PASS_GLOBAL_NAMESPACE( \
__use_pass_itself_##pass_type, \
"USE_PASS must be called in global namespace"); \
extern int TouchPassRegistrar_##pass_type(); \
static int use_pass_itself_##pass_type##_ UNUSED = \
TouchPassRegistrar_##pass_type()
} // namespace ir
} // namespace framework
} // namespace paddle