@@ -28,15 +28,15 @@ namespace paddle {
 namespace framework {
 
 static inline std::unique_ptr<OperatorBase> CreateGradOp(
-    const OperatorBase& op,
-    const std::unordered_set<std::string>& no_grad_set) {
+    const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   OpDescBind op_desc;
   op_desc.SetInputMap(op.Inputs());
   op_desc.SetOutputMap(op.Outputs());
   op_desc.SetType(op.Type());
   op_desc.SetAttrMap(op.Attrs());
   auto& info = OpInfoMap::Instance().Get(op.Type());
-  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set);
+  auto grad_descs = info.GradOpMaker()(op_desc, no_grad_set, grad_to_var);
   std::vector<std::unique_ptr<OperatorBase>> grad_ops;
   grad_ops.reserve(grad_descs.size());
   std::transform(grad_descs.begin(), grad_descs.end(),
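This first hunk threads the new out-parameter, grad_to_var, from CreateGradOp into the registered GradOpMaker, so the maker can record which forward variable each gradient variable it generates belongs to. Below is a minimal, self-contained sketch of that out-parameter pattern; the names MakeGradNames and kGradSuffix are illustrative, not Paddle's API.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

static const char kGradSuffix[] = "@GRAD";

// Emit a gradient variable name for each forward variable, recording the
// grad -> forward mapping in the caller-owned map (like *grad_to_var above).
std::vector<std::string> MakeGradNames(
    const std::vector<std::string>& fwd_vars,
    std::unordered_map<std::string, std::string>* grad_to_var) {
  std::vector<std::string> grads;
  grads.reserve(fwd_vars.size());
  for (const std::string& v : fwd_vars) {
    std::string g = v + kGradSuffix;
    (*grad_to_var)[g] = v;  // caller can later look up the forward variable
    grads.push_back(g);
  }
  return grads;
}

int main() {
  std::unordered_map<std::string, std::string> grad_to_var;
  for (const std::string& g : MakeGradNames({"X", "W", "b"}, &grad_to_var)) {
    std::cout << g << " -> " << grad_to_var[g] << "\n";
  }
}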
@@ -99,7 +99,9 @@ static std::unique_ptr<OperatorBase> NOP() {
 // See Backward.h for details
 static std::unique_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
+    std::unordered_set<std::string>& no_grad_names,
+    std::unordered_map<std::string, std::string>* grad_to_var,
+    size_t& uniq_id) {
   // If all input gradients of forwarding operator do not need to calculate,
   // just return an NOP. Not return null ptr because NOP does not take
   // too much time for calculation, but it is useful for simplifying logic.
@@ -137,7 +139,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
   for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
        ++it, ++local_op_id) {
     auto& fwd = *it;
-    auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
+    auto bwd = BackwardRecursive(*fwd, no_grad_names, grad_to_var, uniq_id);
     ForEachVarName(bwd->Outputs(),
                    [&dup_output_ops, local_op_id](const std::string& out) {
                      dup_output_ops[out].emplace_back(local_op_id);
@@ -189,7 +191,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
     }
   } else {
     std::unique_ptr<OperatorBase> grad_op(
-        CreateGradOp(forwardOp, no_grad_names));
+        CreateGradOp(forwardOp, no_grad_names, grad_to_var));
 
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                           const std::string& grad_input) {
@@ -228,7 +230,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
         *static_cast<const OperatorBase*>(&rnnop.stepnet());
     // create stepnet's gradient op
     rnn_grad_op->set_stepnet(
-        BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
+        BackwardRecursive(stepnet_op, no_grad_names, grad_to_var, uniq_id));
   }
 
   if (net->ops_.empty()) {  // Current no aux op is added to network
@@ -255,7 +257,8 @@ std::unique_ptr<OperatorBase> Backward(
     no_grad_names.insert(name + kGradVarSuffix);
   }
   size_t uid = 0;
-  return BackwardRecursive(forwardOp, no_grad_names, uid);
+  std::unordered_map<std::string, std::string> grad_to_var;
+  return BackwardRecursive(forwardOp, no_grad_names, &grad_to_var, uid);
 }
 
 // ====================================  //
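Note the ownership in this hunk: the public Backward entry point keeps its signature, declares grad_to_var as a stack local, and hands &grad_to_var down the recursion; the map simply goes out of scope once the gradient network is built. A small self-contained sketch of the same shape, with illustrative names (Recurse, BuildMapping), not Paddle's API:

#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>

// Recursive worker: the shared map and counter arrive by pointer, the way
// BackwardRecursive now takes grad_to_var alongside uniq_id.
void Recurse(int depth,
             std::unordered_map<std::string, std::string>* grad_to_var,
             std::size_t* uniq_id) {
  if (depth == 0) return;
  std::string var = "var" + std::to_string((*uniq_id)++);
  (*grad_to_var)[var + "@GRAD"] = var;  // record grad -> forward name
  Recurse(depth - 1, grad_to_var, uniq_id);
}

// Entry point: owns the accumulator, so existing callers see no change.
std::unordered_map<std::string, std::string> BuildMapping(int depth) {
  std::unordered_map<std::string, std::string> grad_to_var;
  std::size_t uid = 0;
  Recurse(depth, &grad_to_var, &uid);
  return grad_to_var;
}

int main() {
  for (const auto& kv : BuildMapping(3)) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
}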
@@ -272,30 +275,31 @@ static bool AllGradInSet(const std::vector<std::string>& names,
 
 std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
     const std::unique_ptr<OpDescBind>& op_desc,
-    std::unordered_set<std::string>& no_grad_vars) {
+    std::unordered_set<std::string>* no_grad_vars,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
   // All input gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
-  if (AllGradInSet(inputs, no_grad_vars)) {
+  if (AllGradInSet(inputs, *no_grad_vars)) {
     return grad_op_descs;  // empty vector
   }
   // All output gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& outputs = op_desc->OutputArgumentNames();
-  if (AllGradInSet(outputs, no_grad_vars)) {
+  if (AllGradInSet(outputs, *no_grad_vars)) {
     for (const std::string& name : inputs) {
-      no_grad_vars.insert(GradVarName(name));
+      no_grad_vars->insert(GradVarName(name));
     }
     return grad_op_descs;  // empty vector
   }
 
   grad_op_descs = OpInfoMap::Instance()
                       .Get(op_desc->Type())
-                      .GradOpMaker()(*op_desc, no_grad_vars);
+                      .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var);
 
   std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
   for (auto& desc : grad_op_descs) {
     for (const std::string& in_name : desc->InputArgumentNames()) {
-      if (no_grad_vars.count(in_name)) {
+      if (no_grad_vars->count(in_name)) {
         std::string prefix = in_name.substr(
             0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
         std::string new_name = prefix + kZeroVarSuffix;
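Besides the extra grad_to_var parameter, this hunk switches no_grad_vars from a non-const reference to a pointer, so mutation is visible at every call site. The loop at the bottom renames gradient inputs that fall in no_grad_vars to zero-filled placeholder variables. The sketch below replays that suffix arithmetic with assumed suffix values ("@GRAD", "@ZERO"); the diff shows only the identifiers kGradVarSuffix and kZeroVarSuffix, not their definitions.

#include <iostream>
#include <string>

// Assumed values; not shown in the diff itself.
static const char kGradVarSuffix[] = "@GRAD";
static const char kZeroVarSuffix[] = "@ZERO";

// Replays the renaming in the loop above. sizeof(kGradVarSuffix) counts
// the array's trailing '\0', hence the "+ 1" when computing how much of
// in_name to keep.
std::string ToZeroVarName(const std::string& in_name) {
  std::string prefix = in_name.substr(
      0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
  return prefix + kZeroVarSuffix;
}

int main() {
  std::cout << ToZeroVarName("X@GRAD") << "\n";  // prints X@ZERO
}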
@@ -315,7 +319,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
 
 std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
     ProgramDescBind& program_desc, int block_idx,
-    std::unordered_set<std::string>& no_grad_vars) {
+    std::unordered_set<std::string>* no_grad_vars,
+    std::unordered_map<std::string, std::string>* grad_to_var) {
   BlockDescBind* cur_block = program_desc.Block(block_idx);
   std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
@@ -323,15 +328,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   std::vector<std::unique_ptr<OpDescBind>> backward_descs;
   for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
     std::vector<std::unique_ptr<OpDescBind>> op_grads =
-        MakeOpGrad(*it, no_grad_vars);
+        MakeOpGrad(*it, no_grad_vars, grad_to_var);
 
     if ((*it)->Type() == "recurrent") {
       PADDLE_ENFORCE_EQ(
           op_grads.size(), size_t(1),
           "rnn_op's gradient process should contain only one op.");
       int step_block_idx = (*it)->GetBlockAttr("stop_block");
-      auto backward_block_op_descs =
-          MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
+      auto backward_block_op_descs = MakeBlockBackward(
+          program_desc, step_block_idx, no_grad_vars, grad_to_var);
       BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
       for (auto& ptr : backward_block_op_descs) {
         backward_block->ops_.push_back(std::move(ptr));
@@ -387,8 +392,9 @@ void AppendBackward(ProgramDescBind& program_desc,
     no_grad_var_names.insert(GradVarName(name));
   }
   const int root_block_idx = 0;
-  auto backward_op_descs =
-      MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names);
+  std::unordered_map<std::string, std::string> grad_to_var;
+  auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
+                                             &no_grad_var_names, &grad_to_var);
   auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
   for (auto& ptr : backward_op_descs) {
     forw_op_descs.push_back(std::move(ptr));
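Finally, AppendBackward owns the grad_to_var map for this descriptor-based path and splices the generated backward descriptors onto the forward block by moving the unique_ptrs, never copying them. A tiny sketch of that splice, with a stand-in Desc type in place of OpDescBind:

#include <deque>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Stand-in for OpDescBind, which is held by unique_ptr in the real code.
struct Desc {
  std::string type;
};

int main() {
  std::deque<std::unique_ptr<Desc>> forw_op_descs;
  forw_op_descs.push_back(std::make_unique<Desc>(Desc{"mul"}));

  std::vector<std::unique_ptr<Desc>> backward_op_descs;
  backward_op_descs.push_back(std::make_unique<Desc>(Desc{"mul_grad"}));

  // Mirror of the loop above: transfer ownership into the block's op list.
  for (auto& ptr : backward_op_descs) {
    forw_op_descs.push_back(std::move(ptr));
  }
  for (const auto& d : forw_op_descs) std::cout << d->type << "\n";
}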