|
|
|
@ -100,8 +100,14 @@ class Pass {
|
|
|
|
|
// Set a pointer to the attribute. Pass takes ownership of the attribute.
|
|
|
|
|
template <typename AttrType>
|
|
|
|
|
void Set(const std::string &attr_name, AttrType *attr) {
|
|
|
|
|
PADDLE_ENFORCE(attrs_.count(attr_name) == 0, "%s already set in the pass",
|
|
|
|
|
attr_name);
|
|
|
|
|
if (default_pass_attrs_.count(attr_name) == 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(attrs_.count(attr_name), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Attribute %s already set in the pass", attr_name));
|
|
|
|
|
} else {
|
|
|
|
|
VLOG(3) << "Setting the attribute " << attr_name << " for the pass "
|
|
|
|
|
<< type_;
|
|
|
|
|
}
|
|
|
|
|
attrs_[attr_name] = attr;
|
|
|
|
|
attr_dels_[attr_name] = [attr, attr_name]() {
|
|
|
|
|
VLOG(3) << "deleting " << attr_name;
|
|
|
|
@ -140,11 +146,21 @@ class Pass {
|
|
|
|
|
required_graph_attrs_.insert(attrs.begin(), attrs.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Pass doesn't take ownership. PassRegistrar should delete default_attrs
|
|
|
|
|
void RegisterDefaultPassAttrs(
|
|
|
|
|
std::map<std::string, boost::any> default_attr_values) {
|
|
|
|
|
for (auto const &attr_name : default_attr_values) {
|
|
|
|
|
default_pass_attrs_.insert(attr_name.first);
|
|
|
|
|
}
|
|
|
|
|
attrs_.insert(default_attr_values.begin(), default_attr_values.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RegisterType(const std::string &type) { type_ = type; }
|
|
|
|
|
|
|
|
|
|
mutable bool applied_{false};
|
|
|
|
|
std::string type_;
|
|
|
|
|
std::unordered_set<std::string> required_pass_attrs_;
|
|
|
|
|
std::unordered_set<std::string> default_pass_attrs_;
|
|
|
|
|
std::unordered_set<std::string> required_graph_attrs_;
|
|
|
|
|
std::map<std::string, boost::any> attrs_;
|
|
|
|
|
std::map<std::string, std::function<void(void)>> attr_dels_;
|
|
|
|
@ -203,16 +219,38 @@ struct PassRegistrar : public Registrar {
|
|
|
|
|
std::unique_ptr<Pass> pass(new PassType());
|
|
|
|
|
pass->RegisterRequiredPassAttrs(this->required_pass_attrs_);
|
|
|
|
|
pass->RegisterRequiredGraphAttrs(this->required_graph_attrs_);
|
|
|
|
|
pass->RegisterDefaultPassAttrs(this->default_attr_values_);
|
|
|
|
|
pass->RegisterType(pass_type);
|
|
|
|
|
return pass;
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~PassRegistrar() {
|
|
|
|
|
for (auto &attr : default_attr_values_) {
|
|
|
|
|
if (default_attr_dels_.find(attr.first) != default_attr_dels_.end()) {
|
|
|
|
|
default_attr_dels_[attr.first]();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
default_attr_values_.clear();
|
|
|
|
|
default_attr_dels_.clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PassRegistrar<PassType> &RequirePassAttr(const std::string &attr) {
|
|
|
|
|
required_pass_attrs_.insert(attr);
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// PassRegistrar takes ownership of default_attr_value
|
|
|
|
|
template <typename AttrType>
|
|
|
|
|
PassRegistrar<PassType> &DefaultPassAttr(const std::string &attr,
|
|
|
|
|
AttrType &&default_attr_value) {
|
|
|
|
|
default_attr_values_[attr] = default_attr_value;
|
|
|
|
|
default_attr_dels_[attr] = [default_attr_value, attr]() {
|
|
|
|
|
delete default_attr_value;
|
|
|
|
|
};
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PassRegistrar<PassType> &RequireGraphAttr(const std::string &attr) {
|
|
|
|
|
required_graph_attrs_.insert(attr);
|
|
|
|
|
return *this;
|
|
|
|
@ -221,6 +259,8 @@ struct PassRegistrar : public Registrar {
|
|
|
|
|
private:
|
|
|
|
|
std::unordered_set<std::string> required_pass_attrs_;
|
|
|
|
|
std::unordered_set<std::string> required_graph_attrs_;
|
|
|
|
|
std::map<std::string, boost::any> default_attr_values_;
|
|
|
|
|
std::map<std::string, std::function<void(void)>> default_attr_dels_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define STATIC_ASSERT_PASS_GLOBAL_NAMESPACE(uniq_name, msg) \
|
|
|
|
|