|
|
|
@ -120,13 +120,19 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpRegistry {
|
|
|
|
|
using OpCreator = std::function<OperatorBase*()>;
|
|
|
|
|
using VarNameMap = OperatorBase::VarNameMap;
|
|
|
|
|
using OpCreator = std::function<OperatorBase*(
|
|
|
|
|
const std::string& /*type*/, const VarNameMap& /*inputs*/,
|
|
|
|
|
const VarNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
|
static void RegisterOp(const std::string& op_type) {
|
|
|
|
|
op_creators()[op_type] = [] { return new OpType; };
|
|
|
|
|
op_creators()[op_type] = [](
|
|
|
|
|
const std::string& type, const VarNameMap& inputs,
|
|
|
|
|
const VarNameMap& outputs, const AttributeMap& attrs) {
|
|
|
|
|
return new OpType(type, inputs, outputs, attrs);
|
|
|
|
|
};
|
|
|
|
|
OpAttrChecker& op_checker = op_checkers()[op_type];
|
|
|
|
|
OpProto& op_proto = OpProtos()[op_type];
|
|
|
|
|
auto maker = ProtoMakerType(&op_proto, &op_checker);
|
|
|
|
@ -141,29 +147,25 @@ class OpRegistry {
|
|
|
|
|
template <typename GradOpType>
|
|
|
|
|
static void RegisterGradOp(const std::string& op_type,
|
|
|
|
|
const std::string& grad_op_type) {
|
|
|
|
|
op_creators()[grad_op_type] = [] { return new GradOpType; };
|
|
|
|
|
op_creators()[grad_op_type] = [](
|
|
|
|
|
const std::string& type, const VarNameMap& inputs,
|
|
|
|
|
const VarNameMap& outputs, const AttributeMap& attrs) {
|
|
|
|
|
return new GradOpType(type, inputs, outputs, attrs);
|
|
|
|
|
};
|
|
|
|
|
grad_ops()[op_type] = grad_op_type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
|
const VarNameMap& inputs,
|
|
|
|
|
const VarNameMap& outputs,
|
|
|
|
|
const AttributeMap& attrs) {
|
|
|
|
|
AttributeMap attrs) {
|
|
|
|
|
auto op_create_it = op_creators().find(type);
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != op_creators().end(),
|
|
|
|
|
"Operator %s cannot be found.", type);
|
|
|
|
|
op_checkers().at(type).Check(attrs);
|
|
|
|
|
|
|
|
|
|
auto op = op_create_it->second();
|
|
|
|
|
op->type_ = type;
|
|
|
|
|
op->inputs_ = inputs;
|
|
|
|
|
op->outputs_ = outputs;
|
|
|
|
|
|
|
|
|
|
op->attrs_ = attrs;
|
|
|
|
|
op_checkers().at(type).Check(op->attrs_);
|
|
|
|
|
|
|
|
|
|
GenerateTempVariableName(op);
|
|
|
|
|
auto op = op_create_it->second(type, inputs, outputs, attrs);
|
|
|
|
|
|
|
|
|
|
op->Init();
|
|
|
|
|
return std::shared_ptr<OperatorBase>(op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -195,7 +197,6 @@ class OpRegistry {
|
|
|
|
|
PADDLE_ENFORCE(!op.IsNetOp(),
|
|
|
|
|
"Use framework::Backward to get backward ops");
|
|
|
|
|
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
|
|
|
|
|
grad_op->Init();
|
|
|
|
|
return grad_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -214,19 +215,6 @@ class OpRegistry {
|
|
|
|
|
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
|
|
|
|
|
return op_checkers_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void GenerateTempVariableName(OperatorBase* op) {
|
|
|
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
|
|
|
for (auto& output : op->outputs_) {
|
|
|
|
|
for (auto& output_name : output.second) {
|
|
|
|
|
if (output_name == kTempVarName) {
|
|
|
|
|
output_name += op->type_;
|
|
|
|
|
output_name += "@";
|
|
|
|
|
output_name += std::to_string(gUniqId.fetch_add(1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Registrar {
|
|
|
|
|