|
|
|
@ -133,6 +133,32 @@ struct ExtractAttribute<std::vector<int64_t>> {
|
|
|
|
|
const std::string& attr_name_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
struct ExtractAttribute<float> {
|
|
|
|
|
explicit ExtractAttribute(const std::string& attr_name)
|
|
|
|
|
: attr_name_(attr_name) {}
|
|
|
|
|
|
|
|
|
|
float* operator()(Attribute& attr) const {
|
|
|
|
|
if (attr.type() == typeid(int)) { // NOLINT
|
|
|
|
|
int val = boost::get<int>(attr);
|
|
|
|
|
attr = static_cast<float>(val);
|
|
|
|
|
} else if (attr.type() == typeid(int64_t)) { // NOLINT
|
|
|
|
|
int64_t val = boost::get<int64_t>(attr);
|
|
|
|
|
attr = static_cast<float>(val);
|
|
|
|
|
}
|
|
|
|
|
float* attr_value = nullptr;
|
|
|
|
|
try {
|
|
|
|
|
attr_value = &boost::get<float>(attr);
|
|
|
|
|
} catch (boost::bad_get& bad_get) {
|
|
|
|
|
PADDLE_THROW("Cannot get attribute %s by type float, its type is %s",
|
|
|
|
|
attr_name_, paddle::platform::demangle(attr.type().name()));
|
|
|
|
|
}
|
|
|
|
|
return attr_value;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::string& attr_name_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
inline proto::AttrType AttrTypeID() {
|
|
|
|
|
Attribute tmp = T();
|
|
|
|
|