|
|
|
@ -274,9 +274,10 @@ static bool AllGradInSet(const std::vector<std::string>& names,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void CreateGradVarInBlock(
|
|
|
|
|
std::unordered_map<std::string, GradVarInfo>* grad_var_record,
|
|
|
|
|
BlockDescBind* block_desc, size_t grad_op_start_index,
|
|
|
|
|
const std::unordered_map<std::string, std::string>& param_name_map) {
|
|
|
|
|
size_t grad_op_start_index,
|
|
|
|
|
const std::unordered_map<std::string, std::string>& param_name_map,
|
|
|
|
|
BlockDescBind* block_desc,
|
|
|
|
|
std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
|
|
|
|
|
auto ops = block_desc->AllOps();
|
|
|
|
|
for (size_t op_index = grad_op_start_index; op_index < ops.size();
|
|
|
|
|
++op_index) {
|
|
|
|
@ -451,11 +452,11 @@ ParamGradInfoMap AppendBackward(
|
|
|
|
|
root_block->NewVar(fill_one_op_out);
|
|
|
|
|
|
|
|
|
|
// create grad_var for all blocks in this program
|
|
|
|
|
CreateGradVarInBlock(&retv, root_block, forward_op_num, grad_to_var);
|
|
|
|
|
CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
|
|
|
|
|
for (size_t block_index = forward_block_num;
|
|
|
|
|
block_index < program_desc.Size(); ++block_index) {
|
|
|
|
|
CreateGradVarInBlock(&retv, program_desc.Block(block_index), 0,
|
|
|
|
|
grad_to_var);
|
|
|
|
|
CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
|
|
|
|
|
&retv);
|
|
|
|
|
}
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|