|
|
|
@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void CreateGradVarInBlock(
|
|
|
|
|
std::unordered_map<std::string, GradVarInfo>* grad_var_record,
|
|
|
|
|
BlockDescBind* block_desc, size_t grad_op_start_index,
|
|
|
|
|
const std::unordered_map<std::string, std::string>& param_name_map) {
|
|
|
|
|
size_t grad_op_start_index,
|
|
|
|
|
const std::unordered_map<std::string, std::string>& param_name_map,
|
|
|
|
|
BlockDescBind* block_desc,
|
|
|
|
|
std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
|
|
|
|
|
auto ops = block_desc->AllOps();
|
|
|
|
|
for (size_t op_index = grad_op_start_index; op_index < ops.size();
|
|
|
|
|
++op_index) {
|
|
|
|
@ -422,9 +423,9 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
return backward_descs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string /*fwd_var_name*/, GradVarInfo /*grad_var_info*/>
|
|
|
|
|
AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
|
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
ParamGradInfoMap AppendBackward(
|
|
|
|
|
ProgramDescBind& program_desc, const VarDescBind& target,
|
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
std::unordered_set<std::string> no_grad_var_names;
|
|
|
|
|
no_grad_var_names.reserve(no_grad_vars.size() + 1);
|
|
|
|
|
no_grad_var_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
|
|
|
|
@ -461,11 +462,11 @@ AppendBackward(ProgramDescBind& program_desc, const VarDescBind& target,
|
|
|
|
|
root_block->Var(fill_one_op_out);
|
|
|
|
|
|
|
|
|
|
// create grad_var for all blocks in this program
|
|
|
|
|
CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
|
|
|
|
|
CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
|
|
|
|
|
for (size_t block_index = forward_block_num;
|
|
|
|
|
block_index < program_desc.Size(); ++block_index) {
|
|
|
|
|
CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
|
|
|
|
|
grad_to_var);
|
|
|
|
|
CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
|
|
|
|
|
&retv);
|
|
|
|
|
}
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|