|
|
|
@ -20,6 +20,7 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/block_desc.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
#include "paddle/framework/program_desc.h"
|
|
|
|
|
#include "paddle/operators/net_op.h"
|
|
|
|
|
#include "paddle/operators/recurrent_op.h"
|
|
|
|
|
|
|
|
|
@ -254,7 +255,7 @@ static bool AllGradInSet(const std::vector<std::string>& names,
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
|
|
|
|
|
const std::unique_ptr<OpDescBind>& op_desc,
|
|
|
|
|
std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
|
|
|
|
@ -295,20 +296,35 @@ std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
|
|
|
|
|
for (auto& p : pending_fill_zeros_ops) {
|
|
|
|
|
grad_op_descs.push_back(std::move(p));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO(fengjiayi): RNN op
|
|
|
|
|
return grad_op_descs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AppendBackwardOpDescs(BlockDescBind& block_desc,
|
|
|
|
|
std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
ProgramDescBind& program_desc, int block_idx,
|
|
|
|
|
std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
BlockDescBind* cur_block = program_desc.Block(block_idx);
|
|
|
|
|
std::deque<std::unique_ptr<OpDescBind>>& op_descs = cur_block->ops_;
|
|
|
|
|
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
|
|
|
|
|
size_t grad_desc_idx = 0;
|
|
|
|
|
std::deque<std::unique_ptr<OpDescBind>>& block_op_descs = block_desc.ops_;
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> backward_descs;
|
|
|
|
|
for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
|
|
|
|
|
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> op_grads =
|
|
|
|
|
MakeGradOpDescs(*it, no_grad_vars);
|
|
|
|
|
MakeOpGrad(*it, no_grad_vars);
|
|
|
|
|
|
|
|
|
|
if ((*it)->Type() == "recurrent") {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
op_grads.size(), size_t(1),
|
|
|
|
|
"rnn_op's gradient process should contain only one op.");
|
|
|
|
|
int step_block_idx = (*it)->GetBlockAttr("stop_block");
|
|
|
|
|
auto backward_block_op_descs =
|
|
|
|
|
MakeBlockBackward(program_desc, step_block_idx, no_grad_vars);
|
|
|
|
|
BlockDescBind* backward_block = program_desc.AppendBlock(*cur_block);
|
|
|
|
|
for (auto& ptr : backward_block_op_descs) {
|
|
|
|
|
backward_block->ops_.push_back(std::move(ptr));
|
|
|
|
|
}
|
|
|
|
|
op_grads[0]->SetBlockAttr("step_block", *backward_block);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto& desc : op_grads) {
|
|
|
|
|
for (const std::string& out_name : desc->OutputArgumentNames()) {
|
|
|
|
|
dup_out_ops[out_name].emplace_back(grad_desc_idx);
|
|
|
|
@ -345,11 +361,24 @@ void AppendBackwardOpDescs(BlockDescBind& block_desc,
|
|
|
|
|
backward_descs.insert(backward_descs.begin() + p.first + 1,
|
|
|
|
|
std::move(p.second));
|
|
|
|
|
}
|
|
|
|
|
// Append backward_descs to BlockDescBind::ops_
|
|
|
|
|
for (std::unique_ptr<OpDescBind>& ptr : backward_descs) {
|
|
|
|
|
block_op_descs.push_back(std::move(ptr));
|
|
|
|
|
return backward_descs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AppendBackward(ProgramDescBind& program_desc,
|
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
std::unordered_set<std::string> no_grad_var_names;
|
|
|
|
|
no_grad_var_names.reserve(no_grad_vars.size() + 1);
|
|
|
|
|
no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
|
|
|
|
|
for (auto& name : no_grad_vars) {
|
|
|
|
|
no_grad_var_names.insert(GradVarName(name));
|
|
|
|
|
}
|
|
|
|
|
const int root_block_idx = 0;
|
|
|
|
|
auto backward_op_descs =
|
|
|
|
|
MakeBlockBackward(program_desc, root_block_idx, no_grad_var_names);
|
|
|
|
|
auto& forw_op_descs = program_desc.Block(root_block_idx)->ops_;
|
|
|
|
|
for (auto& ptr : backward_op_descs) {
|
|
|
|
|
forw_op_descs.push_back(std::move(ptr));
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|