|
|
|
@ -44,8 +44,11 @@ template <template <class...> class V, class... Ts>
|
|
|
|
|
struct variant_caster<V<Ts...>> {
|
|
|
|
|
using Type = V<Ts...>;
|
|
|
|
|
|
|
|
|
|
template <class T>
|
|
|
|
|
bool try_load(handle src, bool convert) {
|
|
|
|
|
template <typename T>
|
|
|
|
|
typename std::enable_if<
|
|
|
|
|
!std::is_same<T, boost::detail::variant::void_>::value,
|
|
|
|
|
bool>::type
|
|
|
|
|
try_load(handle src, bool convert) {
|
|
|
|
|
auto caster = make_caster<T>();
|
|
|
|
|
if (!load_success_ && caster.load(src, convert)) {
|
|
|
|
|
load_success_ = true;
|
|
|
|
@ -55,6 +58,13 @@ struct variant_caster<V<Ts...>> {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
|
|
|
|
|
bool>::type
|
|
|
|
|
try_load(handle src, bool convert) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool load(handle src, bool convert) {
|
|
|
|
|
auto unused = {false, try_load<Ts>(src, convert)...};
|
|
|
|
|
(void)(unused);
|
|
|
|
@ -210,6 +220,45 @@ public:
|
|
|
|
|
|
|
|
|
|
std::string DebugString() { return this->Proto()->DebugString(); }
|
|
|
|
|
|
|
|
|
|
bool HasAttr(const std::string &name) const {
|
|
|
|
|
return attrs_.find(name) != attrs_.end();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::AttrType GetAttrType(const std::string &name) const {
|
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
|
return static_cast<framework::AttrType>(it->second.which() - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> AttrNames() const {
|
|
|
|
|
std::vector<std::string> retv;
|
|
|
|
|
retv.reserve(attrs_.size());
|
|
|
|
|
for (auto &attr : attrs_) {
|
|
|
|
|
retv.push_back(attr.first);
|
|
|
|
|
}
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetAttr(const std::string &name, const Attribute &v) {
|
|
|
|
|
this->attrs_[name] = v;
|
|
|
|
|
need_update_ = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetBlockAttr(const std::string &name, BlockDescBind &block);
|
|
|
|
|
|
|
|
|
|
Attribute GetAttr(const std::string &name) const {
|
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int GetBlockAttr(const std::string &name) const {
|
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
|
return boost::get<BlockDesc *>(it->second)->idx();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
struct SetAttrDescVisitor : public boost::static_visitor<void> {
|
|
|
|
|
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
|
|
|
|
|
mutable OpDesc::Attr *attr_;
|
|
|
|
@ -265,49 +314,13 @@ public:
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool HasAttr(const std::string &name) const {
|
|
|
|
|
return attrs_.find(name) != attrs_.end();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::AttrType GetAttrType(const std::string &name) const {
|
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
|
return static_cast<framework::AttrType>(it->second.which() - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> AttrNames() const {
|
|
|
|
|
std::vector<std::string> retv;
|
|
|
|
|
retv.reserve(attrs_.size());
|
|
|
|
|
for (auto &attr : attrs_) {
|
|
|
|
|
retv.push_back(attr.first);
|
|
|
|
|
}
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetAttr(const std::string &name, const Attribute &v) {
|
|
|
|
|
this->attrs_[name] = v;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetBlockAttr(const std::string &name, BlockDescBind &block);
|
|
|
|
|
|
|
|
|
|
Attribute GetAttr(const std::string &name) const {
|
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int GetBlockAttr(const std::string &name) const {
|
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
|
return boost::get<BlockDesc *>(it->second)->idx();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
OpDesc op_desc_;
|
|
|
|
|
std::unordered_map<std::string, std::vector<std::string>> inputs_;
|
|
|
|
|
std::unordered_map<std::string, std::vector<std::string>> outputs_;
|
|
|
|
|
std::unordered_map<std::string, Attribute> attrs_;
|
|
|
|
|
|
|
|
|
|
// need_update_ indicate there some local changes not be synchronized. If
|
|
|
|
|
// local changes should be synchronized, need_update_ should be set to true.
|
|
|
|
|
bool need_update_{false};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|