|
|
|
@ -273,18 +273,30 @@ static bool AllGradInSet(const std::vector<std::string>& names,
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void CreateGradVarInBlock(BlockDescBind* block_desc,
|
|
|
|
|
size_t grad_op_start_index) {
|
|
|
|
|
static void CreateGradVarInBlock(
|
|
|
|
|
std::unordered_map<std::string, GradVarInfo>* grad_var_record,
|
|
|
|
|
BlockDescBind* block_desc, size_t grad_op_start_index,
|
|
|
|
|
const std::unordered_map<std::string, std::string>& param_name_map) {
|
|
|
|
|
auto ops = block_desc->AllOps();
|
|
|
|
|
for (size_t op_index = grad_op_start_index; op_index < ops.size();
|
|
|
|
|
++op_index) {
|
|
|
|
|
for (const auto& output : ops[op_index]->Outputs()) {
|
|
|
|
|
for (const auto& real_output : output.second) {
|
|
|
|
|
if (!block_desc->HasVar(real_output)) {
|
|
|
|
|
block_desc->NewVar(real_output);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ForEachVarName(ops[op_index]->Outputs(),
|
|
|
|
|
[&](const std::string& grad_var_name) {
|
|
|
|
|
if (block_desc->HasVar(grad_var_name)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
block_desc->NewVar(grad_var_name);
|
|
|
|
|
auto it = param_name_map.find(grad_var_name);
|
|
|
|
|
if (it == param_name_map.end()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto param_var_name = it->second;
|
|
|
|
|
auto& grad_record = (*grad_var_record)[param_var_name];
|
|
|
|
|
grad_record.name_ = grad_var_name;
|
|
|
|
|
grad_record.block_idx_ = block_desc->ID();
|
|
|
|
|
grad_record.op_idx_ = static_cast<int>(op_index);
|
|
|
|
|
return false; /* not break */
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -400,8 +412,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
return backward_descs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
|
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
|
|
|
|
|
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
|
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
std::unordered_set<std::string> no_grad_var_names;
|
|
|
|
|
no_grad_var_names.reserve(no_grad_vars.size() + 1);
|
|
|
|
|
no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
|
|
|
|
@ -423,20 +436,28 @@ void AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
|
|
|
|
|
all_ops.push_back(std::move(fill_one_op));
|
|
|
|
|
size_t forward_op_num = all_ops.size();
|
|
|
|
|
size_t forward_block_num = program_desc.Size();
|
|
|
|
|
|
|
|
|
|
// Insert backward operators
|
|
|
|
|
std::unordered_map<std::string, std::string> grad_to_var;
|
|
|
|
|
auto backward_op_descs = MakeBlockBackward(program_desc, root_block_idx,
|
|
|
|
|
&no_grad_var_names, &grad_to_var);
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, GradVarInfo> retv;
|
|
|
|
|
|
|
|
|
|
// Create Variable
|
|
|
|
|
for (auto& ptr : backward_op_descs) {
|
|
|
|
|
all_ops.push_back(std::move(ptr));
|
|
|
|
|
}
|
|
|
|
|
root_block->NewVar(fill_one_op_out);
|
|
|
|
|
|
|
|
|
|
// create grad_var for all blocks in this program
|
|
|
|
|
CreateGradVarInBlock(root_block, forward_op_num);
|
|
|
|
|
CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
|
|
|
|
|
for (size_t block_index = forward_block_num;
|
|
|
|
|
block_index < program_desc.Size(); ++block_index) {
|
|
|
|
|
CreateGradVarInBlock(program_desc.Block(block_index), 0);
|
|
|
|
|
CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
|
|
|
|
|
grad_to_var);
|
|
|
|
|
}
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|