|
|
@ -19,6 +19,10 @@ limitations under the License. */
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace framework {
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <>
|
|
|
|
|
|
|
|
AttrType AttrTypeID<bool>() {
|
|
|
|
|
|
|
|
return BOOL;
|
|
|
|
|
|
|
|
}
|
|
|
|
template <>
|
|
|
|
template <>
|
|
|
|
AttrType AttrTypeID<int>() {
|
|
|
|
AttrType AttrTypeID<int>() {
|
|
|
|
return INT;
|
|
|
|
return INT;
|
|
|
@ -32,6 +36,10 @@ AttrType AttrTypeID<std::string>() {
|
|
|
|
return STRING;
|
|
|
|
return STRING;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
template <>
|
|
|
|
template <>
|
|
|
|
|
|
|
|
AttrType AttrTypeID<std::vector<bool>>() {
|
|
|
|
|
|
|
|
return BOOLS;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
template <>
|
|
|
|
AttrType AttrTypeID<std::vector<int>>() {
|
|
|
|
AttrType AttrTypeID<std::vector<int>>() {
|
|
|
|
return INTS;
|
|
|
|
return INTS;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -50,6 +58,9 @@ AttrType AttrTypeID<std::vector<std::pair<int, int>>>() {
|
|
|
|
|
|
|
|
|
|
|
|
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
|
|
|
|
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
|
|
|
|
switch (attr_desc.type()) {
|
|
|
|
switch (attr_desc.type()) {
|
|
|
|
|
|
|
|
case paddle::framework::AttrType::BOOL: {
|
|
|
|
|
|
|
|
return attr_desc.b();
|
|
|
|
|
|
|
|
}
|
|
|
|
case paddle::framework::AttrType::INT: {
|
|
|
|
case paddle::framework::AttrType::INT: {
|
|
|
|
return attr_desc.i();
|
|
|
|
return attr_desc.i();
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -59,6 +70,13 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
|
|
|
|
case paddle::framework::AttrType::STRING: {
|
|
|
|
case paddle::framework::AttrType::STRING: {
|
|
|
|
return attr_desc.s();
|
|
|
|
return attr_desc.s();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
case paddle::framework::AttrType::BOOLS: {
|
|
|
|
|
|
|
|
std::vector<bool> val(attr_desc.bools_size());
|
|
|
|
|
|
|
|
for (int i = 0; i < attr_desc.bools_size(); ++i) {
|
|
|
|
|
|
|
|
val[i] = attr_desc.bools(i);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
return val;
|
|
|
|
|
|
|
|
}
|
|
|
|
case paddle::framework::AttrType::INTS: {
|
|
|
|
case paddle::framework::AttrType::INTS: {
|
|
|
|
std::vector<int> val(attr_desc.ints_size());
|
|
|
|
std::vector<int> val(attr_desc.ints_size());
|
|
|
|
for (int i = 0; i < attr_desc.ints_size(); ++i) {
|
|
|
|
for (int i = 0; i < attr_desc.ints_size(); ++i) {
|
|
|
|