|
|
|
@ -22,6 +22,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/framework/attribute.h"
|
|
|
|
|
#include "paddle/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/framework/grad_op_builder.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -127,7 +128,7 @@ class OpRegistry {
|
|
|
|
|
static void RegisterOp(const std::string& op_type) {
|
|
|
|
|
op_creators()[op_type] = [] { return new OpType; };
|
|
|
|
|
OpAttrChecker& op_checker = op_checkers()[op_type];
|
|
|
|
|
OpProto& op_proto = protos()[op_type];
|
|
|
|
|
OpProto& op_proto = OpProtos()[op_type];
|
|
|
|
|
auto maker = ProtoMakerType(&op_proto, &op_checker);
|
|
|
|
|
maker.Validate();
|
|
|
|
|
*op_proto.mutable_type() = op_type;
|
|
|
|
@ -135,17 +136,6 @@ class OpRegistry {
|
|
|
|
|
op_proto.IsInitialized(),
|
|
|
|
|
"Fail to initialize %s's OpProto, because %s is not initialized",
|
|
|
|
|
op_type, op_proto.InitializationErrorString());
|
|
|
|
|
|
|
|
|
|
VarIndexMaps()[op_type].reset(new VarIndexMap());
|
|
|
|
|
auto& varmap = *VarIndexMaps()[op_type];
|
|
|
|
|
int idx = 0;
|
|
|
|
|
for (auto& var : op_proto.inputs()) {
|
|
|
|
|
varmap[var.name()] = idx++;
|
|
|
|
|
}
|
|
|
|
|
idx = 0;
|
|
|
|
|
for (auto& var : op_proto.outputs()) {
|
|
|
|
|
varmap[var.name()] = idx++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename GradOpType>
|
|
|
|
@ -212,22 +202,11 @@ class OpRegistry {
|
|
|
|
|
return grad_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpProto>& protos() {
|
|
|
|
|
static std::unordered_map<std::string, OpProto> protos_;
|
|
|
|
|
return protos_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, std::string>& grad_ops() {
|
|
|
|
|
static std::unordered_map<std::string, std::string> grad_ops_;
|
|
|
|
|
return grad_ops_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
|
|
|
|
|
VarIndexMaps() {
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
|
|
|
|
|
return maps_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpCreator>& op_creators() {
|
|
|
|
|
static std::unordered_map<std::string, OpCreator> op_creators_;
|
|
|
|
|
return op_creators_;
|
|
|
|
|