|
|
|
@ -19,7 +19,7 @@ limitations under the License. */
|
|
|
|
|
#include <type_traits>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include "paddle/framework/attr_checker.h"
|
|
|
|
|
#include "paddle/framework/attribute.h"
|
|
|
|
|
#include "paddle/framework/grad_op_builder.h"
|
|
|
|
|
#include "paddle/framework/op_desc.pb.h"
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
@ -31,43 +31,6 @@ namespace framework {
|
|
|
|
|
struct AttrTypeHelper {
|
|
|
|
|
template <typename T>
|
|
|
|
|
static void SetAttrType(AttrProto* attr);
|
|
|
|
|
|
|
|
|
|
static Attribute GetAttrValue(const AttrDesc& attr_desc) {
|
|
|
|
|
switch (attr_desc.type()) {
|
|
|
|
|
case paddle::framework::AttrType::INT: {
|
|
|
|
|
return attr_desc.i();
|
|
|
|
|
}
|
|
|
|
|
case paddle::framework::AttrType::FLOAT: {
|
|
|
|
|
return attr_desc.f();
|
|
|
|
|
}
|
|
|
|
|
case paddle::framework::AttrType::STRING: {
|
|
|
|
|
return attr_desc.s();
|
|
|
|
|
}
|
|
|
|
|
case paddle::framework::AttrType::INTS: {
|
|
|
|
|
std::vector<int> val(attr_desc.ints_size());
|
|
|
|
|
for (int i = 0; i < attr_desc.ints_size(); ++i) {
|
|
|
|
|
val[i] = attr_desc.ints(i);
|
|
|
|
|
}
|
|
|
|
|
return val;
|
|
|
|
|
}
|
|
|
|
|
case paddle::framework::AttrType::FLOATS: {
|
|
|
|
|
std::vector<float> val(attr_desc.floats_size());
|
|
|
|
|
for (int i = 0; i < attr_desc.floats_size(); ++i) {
|
|
|
|
|
val[i] = attr_desc.floats(i);
|
|
|
|
|
}
|
|
|
|
|
return val;
|
|
|
|
|
}
|
|
|
|
|
case paddle::framework::AttrType::STRINGS: {
|
|
|
|
|
std::vector<std::string> val(attr_desc.strings_size());
|
|
|
|
|
for (int i = 0; i < attr_desc.strings_size(); ++i) {
|
|
|
|
|
val[i] = attr_desc.strings(i);
|
|
|
|
|
}
|
|
|
|
|
return val;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
|
|
|
|
|
return boost::blank();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// this class not only make proto but also init attribute checkers.
|
|
|
|
@ -136,7 +99,7 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
*attr->mutable_name() = name;
|
|
|
|
|
*attr->mutable_comment() = comment;
|
|
|
|
|
attr->set_generated(generated);
|
|
|
|
|
AttrTypeHelper::SetAttrType<T>(attr);
|
|
|
|
|
attr->set_type(AttrTypeID<T>());
|
|
|
|
|
return op_checker_->AddAttrChecker<T>(name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -297,7 +260,7 @@ class OpRegistry {
|
|
|
|
|
|
|
|
|
|
AttributeMap attrs;
|
|
|
|
|
for (auto& attr : op_desc.attrs()) {
|
|
|
|
|
attrs[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
|
|
|
|
|
attrs[attr.name()] = GetAttrValue(attr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return CreateOp(op_desc.type(), inputs, outputs, attrs);
|
|
|
|
|