|
|
|
@ -175,17 +175,20 @@ Add a mark to which output is temporary is helpful for future optimization.
|
|
|
|
|
bool has_temporary_output_{false};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class NOPMaker : public OpProtoAndCheckerMaker {};
|
|
|
|
|
class NOPMaker : public OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct OpInfo {
|
|
|
|
|
std::function creator_;
|
|
|
|
|
std::function<OperatorBase*()> creator_;
|
|
|
|
|
std::string grad_op_type_;
|
|
|
|
|
OpProto* proto_;
|
|
|
|
|
OpAttrChecker* checker_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpRegistry {
|
|
|
|
|
using OpCreator = std::function<OperatorBase*()>;
|
|
|
|
|
using VarIndexMap = std::unordered_map<std::string, int>;
|
|
|
|
|
using VarNameList = std::vector<std::string>;
|
|
|
|
|
|
|
|
|
@ -201,28 +204,28 @@ class OpRegistry {
|
|
|
|
|
if (std::type_index(typeid(ProtoMakerType)) !=
|
|
|
|
|
std::type_index(typeid(NOPMaker))) {
|
|
|
|
|
op_info.proto_ = new OpProto;
|
|
|
|
|
op_info.op_checker_ = new OpAttrChecker;
|
|
|
|
|
auto maker = ProtoMakerType(op_info.proto_, op_info.op_checker_);
|
|
|
|
|
op_info.checker_ = new OpAttrChecker;
|
|
|
|
|
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
|
|
|
|
|
maker.Validate();
|
|
|
|
|
*op_info.proto_->mutable_type() = op_type;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
op_info.proto_->IsInitialized(),
|
|
|
|
|
"Fail to initialize %s's OpProto, because %s is not initialized",
|
|
|
|
|
op_type, op_info.proto_->InitializationErrorString());
|
|
|
|
|
//======will be refactored in following PRs============//
|
|
|
|
|
// ======will be refactored in following PRs============ //
|
|
|
|
|
VarIndexMaps()[op_type].reset(new VarIndexMap());
|
|
|
|
|
auto& varmap = *VarIndexMaps()[op_type];
|
|
|
|
|
int idx = 0;
|
|
|
|
|
for (auto& var : op_proto.inputs()) {
|
|
|
|
|
for (auto& var : op_info.proto_->inputs()) {
|
|
|
|
|
varmap[var.name()] = idx++;
|
|
|
|
|
}
|
|
|
|
|
idx = 0;
|
|
|
|
|
for (auto& var : op_proto.outputs()) {
|
|
|
|
|
for (auto& var : op_info.proto_->outputs()) {
|
|
|
|
|
varmap[var.name()] = idx++;
|
|
|
|
|
}
|
|
|
|
|
//================================================//
|
|
|
|
|
// ================================================ //
|
|
|
|
|
}
|
|
|
|
|
op_info_map.insert(std::make_pair(op_type, op_info));
|
|
|
|
|
op_info_map().insert(std::make_pair(op_type, op_info));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
@ -281,8 +284,8 @@ class OpRegistry {
|
|
|
|
|
return grad_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<const std::string, const OpInfo>& op_info_map() {
|
|
|
|
|
static std::unordered_map<const std::string, const OpInfo> op_info_map_;
|
|
|
|
|
static std::unordered_map<std::string, const OpInfo>& op_info_map() {
|
|
|
|
|
static std::unordered_map<std::string, const OpInfo> op_info_map_;
|
|
|
|
|
return op_info_map_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -321,7 +324,7 @@ class Registrar {
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
|
class OpRegistrar : public Registrar {
|
|
|
|
|
public:
|
|
|
|
|
OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
|
|
|
|
|
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
|
|
|
|
|
OpRegistrar(const char* op_type, const char* grad_op_type) {
|
|
|
|
|
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type, grad_op_type);
|
|
|
|
|
}
|
|
|
|
|