revert-3824-remove_grad_op_type
fengjiayi 8 years ago
parent 55fac55107
commit ab08575adf

@ -76,8 +76,16 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
} }
OperatorBase* BuildGradOp(const OperatorBase* op) { OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_); auto it = op_info_map().find(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)(); PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", op->type);
std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->type);
it = op_info_map().find(grad_op_type);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", grad_op_type);
OperatorBase* grad_op = it->second.creator_();
grad_op->type_ = grad_op_type; grad_op->type_ = grad_op_type;
grad_op->attrs_ = op->attrs_; grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format"); grad_op->attrs_.erase("input_format");

File diff suppressed because it is too large Load Diff

@ -173,13 +173,13 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not //! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python. //! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> { m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
auto &protos = OpRegistry::protos(); auto &op_info_map = OpRegistry::op_info_map();
std::vector<py::bytes> ret_values; std::vector<py::bytes> ret_values;
for (auto it = protos.begin(); it != protos.end(); ++it) { for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) {
PADDLE_ENFORCE(it->second.IsInitialized(), const OpProto *proto = it->second.proto_;
"OpProto must all be initialized"); PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized");
std::string str; std::string str;
PADDLE_ENFORCE(it->second.SerializeToString(&str), PADDLE_ENFORCE(proto->SerializeToString(&str),
"Serialize OpProto Error. This could be a bug of Paddle."); "Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.push_back(py::bytes(str)); ret_values.push_back(py::bytes(str));
} }

Loading…
Cancel
Save