|
|
|
@ -120,6 +120,57 @@ class EnumInContainer {
|
|
|
|
|
std::unordered_set<T> container_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ExtractAttribute {
|
|
|
|
|
explicit ExtractAttribute(const std::string& attr_name)
|
|
|
|
|
: attr_name_(attr_name) {}
|
|
|
|
|
|
|
|
|
|
T* operator()(Attribute& attr) const {
|
|
|
|
|
T* attr_value = nullptr;
|
|
|
|
|
try {
|
|
|
|
|
attr_value = &boost::get<T>(attr);
|
|
|
|
|
} catch (boost::bad_get& bad_get) {
|
|
|
|
|
PADDLE_THROW("Cannot get attribute %s by type %s, its type is %s",
|
|
|
|
|
attr_name_, typeid(T).name(), attr.type().name());
|
|
|
|
|
}
|
|
|
|
|
return attr_value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::string& attr_name_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// special handle bool
|
|
|
|
|
// FIXME(yuyang18): Currently we cast bool into int in python binding. It is
|
|
|
|
|
// hard to change the logic there. In another way, we should correct handle
|
|
|
|
|
// if the user set `some_flag=1`.
|
|
|
|
|
//
|
|
|
|
|
// FIX ME anytime if there is a better solution.
|
|
|
|
|
template <>
|
|
|
|
|
struct ExtractAttribute<bool> {
|
|
|
|
|
explicit ExtractAttribute(const std::string& attr_name)
|
|
|
|
|
: attr_name_(attr_name) {}
|
|
|
|
|
|
|
|
|
|
bool* operator()(Attribute& attr) const {
|
|
|
|
|
if (attr.type() == typeid(int)) { // NOLINT
|
|
|
|
|
int val = boost::get<int>(attr);
|
|
|
|
|
attr = static_cast<bool>(val);
|
|
|
|
|
} else if (attr.type() == typeid(float)) { // NOLINT
|
|
|
|
|
float val = boost::get<float>(attr);
|
|
|
|
|
attr = static_cast<bool>(val);
|
|
|
|
|
}
|
|
|
|
|
bool* attr_value = nullptr;
|
|
|
|
|
try {
|
|
|
|
|
attr_value = &boost::get<bool>(attr);
|
|
|
|
|
} catch (boost::bad_get& bad_get) {
|
|
|
|
|
PADDLE_THROW("Cannot get attribute %s by type bool, its type is %s",
|
|
|
|
|
attr_name_, attr.type().name());
|
|
|
|
|
}
|
|
|
|
|
return attr_value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::string& attr_name_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// check whether a certain attribute fit its limits
|
|
|
|
|
// an attribute can have more than one limits
|
|
|
|
|
template <typename T>
|
|
|
|
@ -171,9 +222,10 @@ class TypedAttrChecker {
|
|
|
|
|
attr_map[attr_name_] = val;
|
|
|
|
|
}
|
|
|
|
|
Attribute& attr = attr_map.at(attr_name_);
|
|
|
|
|
T& attr_value = boost::get<T>(attr);
|
|
|
|
|
ExtractAttribute<T> extract_attr(attr_name_);
|
|
|
|
|
T* attr_value = extract_attr(attr);
|
|
|
|
|
for (const auto& checker : value_checkers_) {
|
|
|
|
|
checker(attr_value);
|
|
|
|
|
checker(*attr_value);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|