|
|
@ -76,7 +76,7 @@ static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
OperatorBase* BuildGradOp(const OperatorBase* op) {
|
|
|
|
OperatorBase* BuildGradOp(const OperatorBase* op) {
|
|
|
|
std::string grad_op_type = OpRegistry::grad_ops().at(op->type_);
|
|
|
|
const std::string& grad_op_type = OpRegistry::grad_ops().at(op->Type());
|
|
|
|
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
|
|
|
|
OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
|
|
|
|
grad_op->type_ = grad_op_type;
|
|
|
|
grad_op->type_ = grad_op_type;
|
|
|
|
grad_op->attrs_ = op->attrs_;
|
|
|
|
grad_op->attrs_ = op->attrs_;
|
|
|
|