|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/backward.h"
|
|
|
|
|
#include "paddle/operators/net_op.h"
|
|
|
|
|
|
|
|
|
|
#include <list>
|
|
|
|
|
#include <memory>
|
|
|
|
@ -24,6 +25,32 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
static inline std::unique_ptr<OperatorBase> CreateGradOp(
|
|
|
|
|
const OperatorBase& op) {
|
|
|
|
|
OpDescBind op_desc;
|
|
|
|
|
op_desc.SetInputMap(op.Inputs());
|
|
|
|
|
op_desc.SetOutputMap(op.Outputs());
|
|
|
|
|
op_desc.SetType(op.Type());
|
|
|
|
|
op_desc.SetAttrMap(op.Attrs());
|
|
|
|
|
auto& info = OpInfoMap::Instance().Get(op.Type());
|
|
|
|
|
auto grad_descs = info.grad_op_maker_(op_desc);
|
|
|
|
|
std::vector<std::unique_ptr<OperatorBase>> grad_ops;
|
|
|
|
|
grad_ops.reserve(grad_descs.size());
|
|
|
|
|
std::transform(
|
|
|
|
|
grad_descs.begin(), grad_descs.end(), std::back_inserter(grad_ops),
|
|
|
|
|
[](OpDescBind& grad_desc) { return OpRegistry::CreateOp(&grad_desc); });
|
|
|
|
|
PADDLE_ENFORCE_GT(grad_ops.size(), 0);
|
|
|
|
|
if (grad_ops.size() == 1) {
|
|
|
|
|
return std::move(grad_ops[0]);
|
|
|
|
|
} else {
|
|
|
|
|
auto net_op = new operators::NetOp();
|
|
|
|
|
for (auto& grad_op : grad_ops) {
|
|
|
|
|
net_op->AppendOp(std::move(grad_op));
|
|
|
|
|
}
|
|
|
|
|
return std::unique_ptr<OperatorBase>(net_op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Map, typename T>
|
|
|
|
|
static void ForEachVarName(const Map& names, T callback) {
|
|
|
|
|
for (auto& name : names) {
|
|
|
|
@ -154,10 +181,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
|
net->InsertOp(pos.first + 1, std::move(pos.second));
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
OpDescBind fwd_desc;
|
|
|
|
|
fwd_desc.SetInput(forwardOp.Inputs());
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
|
|
|
|
|
std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
|
|
|
|
|
PADDLE_ENFORCE(grad_op != nullptr);
|
|
|
|
|
|
|
|
|
|
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
|
|
|
|
|
const std::string& grad_input) {
|
|
|
|
|