@@ -15,6 +15,8 @@
 #include "paddle/framework/backward.h"
 
 #include <list>
+#include <memory>
+
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
@@ -43,11 +45,11 @@ static bool AllInSet(
   return all_in_set;
 }
 
-static std::shared_ptr<OperatorBase> NOP() {
-  auto net_op = std::make_shared<operators::NetOp>();
+static std::unique_ptr<OperatorBase> NOP() {
+  auto net_op = new operators::NetOp();
   net_op->SetType("@NOP@");
   net_op->CompleteAddOp();
-  return net_op;
+  return std::unique_ptr<OperatorBase>(net_op);
 }
 
 // Get backward operator from a forward operator, a recursive implementation.
@@ -62,11 +64,7 @@ static std::shared_ptr<OperatorBase> NOP() {
 // operator, in a complex situation, it maybe a NetOp.
 //
 // See Backward.h for details
-static std::shared_ptr<OperatorBase> BackwardRecursive(
-    const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
-
-std::shared_ptr<OperatorBase> BackwardRecursive(
+static std::unique_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
   // If all input gradients of forwarding operator do not need to calculate,
@@ -91,7 +89,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   }
 
   // Returned gradient network
-  auto net = std::make_shared<operators::NetOp>();
+  auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
 
   if (forwardOp.IsNetOp()) {
     // Because forwardOp is a net op, it can static_cast.
@@ -105,14 +103,14 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     // reversely travel forwardNet and collect all duplicate outputs.
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
-      auto fwd = *it;
+      auto& fwd = *it;
       auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
-      net->AddOp(bwd);
       ForEachVarName(bwd->Outputs(),
                      [&dup_output_ops, local_op_id](const std::string& out) {
                        dup_output_ops[out].emplace_back(local_op_id);
                        return false;
                      });
+      net->AddOp(std::move(bwd));
     }
     // Get unique ID for this method.
     auto uid = uniq_id++;
@@ -122,7 +120,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     // to handle this case. For each duplicate output, rename it to an alias
     // (original name with a offset), append an `add` op for its operator,
     // and finally sum all the alias variable to the final output variable y.
-    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
+    using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
     std::list<Pos> insert_position;
     for (auto& dup_output_op : dup_output_ops) {
       const std::string& name = dup_output_op.first;
@@ -150,13 +148,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
         [](const Pos& l, const Pos& r) { return l.first > r.first; });
 
     for (auto& pos : insert_position) {
-      net->InsertOp(pos.first + 1, pos.second);
+      net->InsertOp(pos.first + 1, std::move(pos.second));
    }
   } else {
-    std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
+    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
 
-    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
-                                       grad_op](const std::string& grad_input) {
+    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
+                                          const std::string& grad_input) {
       if (no_grad_names.count(grad_input)) {
         // +1 for \0
         std::string prefix = grad_input.substr(
@@ -190,23 +188,23 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
       const auto& stepnet_op =
           *static_cast<const OperatorBase*>(&rnnop.stepnet());
       // create stepnet's gradient op
-      auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
       rnn_grad_op->set_stepnet(
-          std::static_pointer_cast<operators::NetOp>(grad_stepnet));
+          BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
     }
 
     if (net->ops_.empty()) {  // Current no aux op is added to network
       return grad_op;
     }
-    net->AddOp(grad_op);
+    net->AddOp(std::move(grad_op));
   }
   net->SetType("@GENERATED_BACKWARD@");
   net->CompleteAddOp();
-  return net;
-}  // namespace framework
+  return std::unique_ptr<OperatorBase>(
+      static_cast<OperatorBase*>(net.release()));
+}
 
 // See header for comments
-std::shared_ptr<OperatorBase> Backward(
+std::unique_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_names;