|
|
|
@ -165,7 +165,7 @@ template <typename T>
|
|
|
|
|
class GreaterThanChecker {
|
|
|
|
|
public:
|
|
|
|
|
explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
|
|
|
|
|
void operator()(T& value) const {
|
|
|
|
|
void operator()(const T& value) const {
|
|
|
|
|
PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -177,7 +177,7 @@ template <typename T>
|
|
|
|
|
class EqualGreaterThanChecker {
|
|
|
|
|
public:
|
|
|
|
|
explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
|
|
|
|
|
void operator()(T& value) const {
|
|
|
|
|
void operator()(const T& value) const {
|
|
|
|
|
PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -193,7 +193,7 @@ class DefaultValueSetter {
|
|
|
|
|
public:
|
|
|
|
|
explicit DefaultValueSetter(T default_value)
|
|
|
|
|
: default_value_(default_value) {}
|
|
|
|
|
void operator()(T& value) const { value = default_value_; } // NOLINT
|
|
|
|
|
void operator()(T* value) const { *value = default_value_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T default_value_;
|
|
|
|
@ -203,7 +203,7 @@ template <typename T>
|
|
|
|
|
class EnumInContainer {
|
|
|
|
|
public:
|
|
|
|
|
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
|
|
|
|
|
void operator()(T& val) const {
|
|
|
|
|
void operator()(const T& val) const {
|
|
|
|
|
PADDLE_ENFORCE(container_.find(val) != container_.end(),
|
|
|
|
|
"Value %s is not in enum container %s", val,
|
|
|
|
|
ContainerDebugString());
|
|
|
|
@ -232,7 +232,8 @@ class EnumInContainer {
|
|
|
|
|
// an attribute can have more than one limits
|
|
|
|
|
template <typename T>
|
|
|
|
|
class TypedAttrChecker {
|
|
|
|
|
typedef std::function<void(T&)> ValueChecker;
|
|
|
|
|
typedef std::function<void(T*)> DefaultValueChecker;
|
|
|
|
|
typedef std::function<void(const T&)> ValueChecker;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
explicit TypedAttrChecker(const std::string& attr_name)
|
|
|
|
@ -268,17 +269,17 @@ class TypedAttrChecker {
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void operator()(AttributeMap& attr_map) const { // NOLINT
|
|
|
|
|
if (!attr_map.count(attr_name_)) {
|
|
|
|
|
void operator()(AttributeMap* attr_map) const {
|
|
|
|
|
if (!attr_map->count(attr_name_)) {
|
|
|
|
|
// user do not set this attr
|
|
|
|
|
PADDLE_ENFORCE(!default_value_setter_.empty(),
|
|
|
|
|
"Attribute '%s' is required!", attr_name_);
|
|
|
|
|
// default_value_setter_ has no more than one element
|
|
|
|
|
T val;
|
|
|
|
|
(default_value_setter_[0])(val);
|
|
|
|
|
attr_map[attr_name_] = val;
|
|
|
|
|
(default_value_setter_[0])(&val);
|
|
|
|
|
(*attr_map)[attr_name_] = val;
|
|
|
|
|
}
|
|
|
|
|
Attribute& attr = attr_map.at(attr_name_);
|
|
|
|
|
Attribute& attr = attr_map->at(attr_name_);
|
|
|
|
|
ExtractAttribute<T> extract_attr(attr_name_);
|
|
|
|
|
T* attr_value = extract_attr(attr);
|
|
|
|
|
for (const auto& checker : value_checkers_) {
|
|
|
|
@ -289,12 +290,12 @@ class TypedAttrChecker {
|
|
|
|
|
private:
|
|
|
|
|
std::string attr_name_;
|
|
|
|
|
std::vector<ValueChecker> value_checkers_;
|
|
|
|
|
std::vector<ValueChecker> default_value_setter_;
|
|
|
|
|
std::vector<DefaultValueChecker> default_value_setter_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// check whether op's all attributes fit their own limits
|
|
|
|
|
class OpAttrChecker {
|
|
|
|
|
typedef std::function<void(AttributeMap&)> AttrChecker;
|
|
|
|
|
typedef std::function<void(AttributeMap*)> AttrChecker;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
template <typename T>
|
|
|
|
@ -304,7 +305,7 @@ class OpAttrChecker {
|
|
|
|
|
return *(checker.target<TypedAttrChecker<T>>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Check(AttributeMap& attr_map) const { // NOLINT
|
|
|
|
|
void Check(AttributeMap* attr_map) const {
|
|
|
|
|
for (const auto& checker : attr_checkers_) {
|
|
|
|
|
checker(attr_map);
|
|
|
|
|
}
|
|
|
|
|