@@ -18,6 +18,7 @@
 #include <deque>
 #include <list>
 #include <memory>
+#include <unordered_set>
 
 #include "paddle/framework/block_desc.h"
 #include "paddle/framework/op_registry.h"
@@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names,
   return true;
 }
 
+static std::string FwdName(const std::string& grad_name) {
+  auto pos = grad_name.find("@GRAD");
+  if (pos == std::string::npos) {
+    return "";
+  } else {
+    return grad_name.substr(0, pos);
+  }
+}
+
 static void CreateGradVarInBlock(
     size_t grad_op_start_index,
     const std::unordered_map<std::string, std::string>& param_name_map,
@@ -294,6 +304,7 @@ static void CreateGradVarInBlock(
   for (size_t op_index = grad_op_start_index; op_index < ops.size();
        ++op_index) {
     bool need_infer_shape = false;
+    std::unordered_set<std::string> new_vars;
     ForEachVarName(ops[op_index]->Outputs(),
                    [&](const std::string& grad_var_name) {
                      if (block_desc->HasVar(grad_var_name)) {
@@ -301,8 +312,7 @@ static void CreateGradVarInBlock(
                      }
                      need_infer_shape = true;
                      auto var = block_desc->Var(grad_var_name);
-                     // FIXME(qiao) infer the datatype
-                     var->SetDataType(framework::DataType::FP32);
+                     new_vars.insert(var->Name());
                      auto it = param_name_map.find(grad_var_name);
                      if (it == param_name_map.end()) {
                        return false;
@@ -316,6 +326,21 @@ static void CreateGradVarInBlock(
                    });
     if (need_infer_shape) {
       ops[op_index]->InferVarType(block_desc);
+      for (auto& arg : ops[op_index]->OutputArgumentNames()) {
+        if (new_vars.find(arg) == new_vars.end()) {
+          continue;
+        }
+        auto pname = FwdName(arg);
+        auto* param = block_desc->FindVar(pname);
+        auto* grad = block_desc->FindVar(arg);
+        if (param == nullptr) {
+          LOG(WARNING) << "Cannot find forward variable of " << arg
+                       << ". Set its gradient to FP32";
+          grad->SetDataType(DataType::FP32);
+        } else {
+          grad->SetDataType(param->GetDataType());
+        }
+      }
       ops[op_index]->InferShape(*block_desc);
     }
   }
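
The loop added above gives each newly created gradient variable the data type of its forward counterpart, located by stripping the "@GRAD" suffix via FwdName; when no forward variable is found, it logs a warning and falls back to FP32. Below is a standalone sketch of that suffix rule only, using a hypothetical variable name and plain std::cout instead of Paddle's descriptors:

// Sketch only: mirrors the FwdName suffix rule from the hunk above.
#include <iostream>
#include <string>

static std::string FwdName(const std::string& grad_name) {
  auto pos = grad_name.find("@GRAD");
  return pos == std::string::npos ? "" : grad_name.substr(0, pos);
}

int main() {
  // "fc_0.w_0" is a hypothetical parameter name, not taken from the diff.
  std::cout << FwdName("fc_0.w_0@GRAD") << "\n";   // prints "fc_0.w_0"
  std::cout << FwdName("fc_0.w_0").empty() << "\n";  // prints 1: no suffix, so the FP32 fallback applies
  return 0;
}
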
@@ -368,7 +393,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
     ProgramDescBind& program_desc, int block_idx,
     std::unordered_set<std::string>* no_grad_vars,
     std::unordered_map<std::string, std::string>* grad_to_var) {
-  BlockDescBind* cur_block = program_desc.Block(block_idx);
+  BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
   std::vector<OpDescBind*> op_descs = cur_block->AllOps();
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
   size_t grad_desc_idx = 0;
@@ -443,7 +468,7 @@ ParamGradInfoMap AppendBackward(
   }
 
   const int root_block_idx = 0;
-  auto root_block = program_desc.Block(root_block_idx);
+  auto root_block = program_desc.MutableBlock(root_block_idx);
 
   // insert fill one op for target
   // TODO(qiao) add some check to the target.
@@ -492,7 +517,7 @@ ParamGradInfoMap AppendBackward(
   CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
   for (size_t block_index = forward_block_num;
        block_index < program_desc.Size(); ++block_index) {
-    CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
+    CreateGradVarInBlock(0, grad_to_var, program_desc.MutableBlock(block_index),
                          &retv);
   }
   return retv;
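
The Block() -> MutableBlock() substitutions in the last three hunks only touch call sites that go on to modify the block (appending gradient ops and variables). A minimal sketch of that const/non-const accessor split follows; ProgramDesc, BlockDesc, and the member signatures are simplified stand-ins assumed for illustration, not the real ProgramDescBind API:

// Sketch of the accessor split implied by the rename; simplified types.
#include <cstddef>
#include <vector>

struct BlockDesc {
  // op/var descriptions would live here
};

class ProgramDesc {
 public:
  // Read-only view: enough for inspection, cannot be used to append ops.
  const BlockDesc& Block(size_t idx) const { return blocks_[idx]; }
  // Mutable access: what backward construction needs when it inserts
  // gradient ops and variables into the block.
  BlockDesc* MutableBlock(size_t idx) { return &blocks_[idx]; }

 private:
  std::vector<BlockDesc> blocks_ = std::vector<BlockDesc>(1);
};

int main() {
  ProgramDesc program;
  BlockDesc* root = program.MutableBlock(0);  // caller states the intent to mutate
  (void)root;
  return 0;
}
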