|
|
|
@ -39,9 +39,9 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
|
|
|
|
|
std::transform(grad_descs.begin(), grad_descs.end(),
|
|
|
|
|
std::back_inserter(grad_ops),
|
|
|
|
|
[](const std::unique_ptr<OpDescBind>& grad_desc) {
|
|
|
|
|
return OpRegistry::CreateOp(grad_desc.get());
|
|
|
|
|
return OpRegistry::CreateOp(*grad_desc);
|
|
|
|
|
});
|
|
|
|
|
PADDLE_ENFORCE_GT(grad_ops.size(), 0);
|
|
|
|
|
PADDLE_ENFORCE(!grad_ops.empty());
|
|
|
|
|
if (grad_ops.size() == 1) {
|
|
|
|
|
return std::move(grad_ops[0]);
|
|
|
|
|
} else {
|
|
|
|
|