revert-3824-remove_grad_op_type
fengjiayi 8 years ago
parent 55fac55107
commit ab08575adf

@ -76,8 +76,16 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
}
OperatorBase* BuildGradOp(const OperatorBase* op) {
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
auto it = op_info_map().find(op->type_);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", op->type);
std::string grad_op_type = it->second.grad_op_type_;
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
op->type);
it = op_info_map().find(grad_op_type);
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
"'%s' has not been registered.", grad_op_type);
OperatorBase* grad_op = it->second.creator_();
grad_op->type_ = grad_op_type;
grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format");

File diff suppressed because it is too large Load Diff

@ -173,13 +173,13 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
auto &protos = OpRegistry::protos();
auto &op_info_map = OpRegistry::op_info_map();
std::vector<py::bytes> ret_values;
for (auto it = protos.begin(); it != protos.end(); ++it) {
PADDLE_ENFORCE(it->second.IsInitialized(),
"OpProto must all be initialized");
for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) {
const OpProto *proto = it->second.proto_;
PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized");
std::string str;
PADDLE_ENFORCE(it->second.SerializeToString(&str),
PADDLE_ENFORCE(proto->SerializeToString(&str),
"Serialize OpProto Error. This could be a bug of Paddle.");
ret_values.push_back(py::bytes(str));
}

Loading…
Cancel
Save