|
|
@ -61,17 +61,6 @@ struct OperatorRegistrar : public Registrar {
|
|
|
|
|
|
|
|
|
|
|
|
class OpRegistry {
|
|
|
|
class OpRegistry {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
template <typename OpType, typename ProtoMakerType, typename GradOpType>
|
|
|
|
|
|
|
|
static void RegisterOp(const std::string& op_type,
|
|
|
|
|
|
|
|
const std::string& grad_op_type) {
|
|
|
|
|
|
|
|
OperatorRegistrar<OpType, ProtoMakerType> reg(op_type.c_str());
|
|
|
|
|
|
|
|
reg.info.grad_op_type_ = grad_op_type;
|
|
|
|
|
|
|
|
// register gradient op
|
|
|
|
|
|
|
|
if (!grad_op_type.empty()) {
|
|
|
|
|
|
|
|
OperatorRegistrar<GradOpType> grad_reg(grad_op_type.c_str());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
const VariableNameMap& inputs,
|
|
|
|
const VariableNameMap& inputs,
|
|
|
|
const VariableNameMap& outputs,
|
|
|
|
const VariableNameMap& outputs,
|
|
|
|