|
|
|
@ -82,10 +82,7 @@ class DefaultValueSetter {
|
|
|
|
|
public:
|
|
|
|
|
explicit DefaultValueSetter(T default_value)
|
|
|
|
|
: default_value_(default_value) {}
|
|
|
|
|
void operator()(T* value) const {
|
|
|
|
|
PADDLE_ENFORCE(value != nullptr, "Can not set default value to nullptr");
|
|
|
|
|
*value = default_value_;
|
|
|
|
|
}
|
|
|
|
|
void operator()(T& value) const { value = default_value_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T default_value_;
|
|
|
|
@ -202,7 +199,6 @@ struct ExtractAttribute<int64_t> {
|
|
|
|
|
template <typename T>
|
|
|
|
|
class TypedAttrChecker {
|
|
|
|
|
typedef std::function<void(T&)> ValueChecker;
|
|
|
|
|
typedef std::function<void(T*)> ValueSetter;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
explicit TypedAttrChecker(const std::string& attr_name)
|
|
|
|
@ -245,7 +241,7 @@ class TypedAttrChecker {
|
|
|
|
|
"Attribute '%s' is required!", attr_name_);
|
|
|
|
|
// default_value_setter_ has no more than one element
|
|
|
|
|
T val;
|
|
|
|
|
(default_value_setter_[0])(&val);
|
|
|
|
|
(default_value_setter_[0])(val);
|
|
|
|
|
attr_map[attr_name_] = val;
|
|
|
|
|
}
|
|
|
|
|
Attribute& attr = attr_map.at(attr_name_);
|
|
|
|
@ -259,7 +255,7 @@ class TypedAttrChecker {
|
|
|
|
|
private:
|
|
|
|
|
std::string attr_name_;
|
|
|
|
|
std::vector<ValueChecker> value_checkers_;
|
|
|
|
|
std::vector<ValueSetter> default_value_setter_;
|
|
|
|
|
std::vector<ValueChecker> default_value_setter_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// check whether op's all attributes fit their own limits
|
|
|
|
|