|
|
|
@ -18,6 +18,7 @@
|
|
|
|
|
#include <deque>
|
|
|
|
|
#include <list>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/block_desc.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
@ -285,6 +286,15 @@ static bool AllGradInSet(const std::vector<std::string>& names,
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::string FwdName(const std::string& grad_name) {
|
|
|
|
|
auto pos = grad_name.find("@GRAD");
|
|
|
|
|
if (pos == std::string::npos) {
|
|
|
|
|
return "";
|
|
|
|
|
} else {
|
|
|
|
|
return grad_name.substr(0, pos);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void CreateGradVarInBlock(
|
|
|
|
|
size_t grad_op_start_index,
|
|
|
|
|
const std::unordered_map<std::string, std::string>& param_name_map,
|
|
|
|
@ -294,6 +304,7 @@ static void CreateGradVarInBlock(
|
|
|
|
|
for (size_t op_index = grad_op_start_index; op_index < ops.size();
|
|
|
|
|
++op_index) {
|
|
|
|
|
bool need_infer_shape = false;
|
|
|
|
|
std::unordered_set<std::string> new_vars;
|
|
|
|
|
ForEachVarName(ops[op_index]->Outputs(),
|
|
|
|
|
[&](const std::string& grad_var_name) {
|
|
|
|
|
if (block_desc->HasVar(grad_var_name)) {
|
|
|
|
@ -301,8 +312,7 @@ static void CreateGradVarInBlock(
|
|
|
|
|
}
|
|
|
|
|
need_infer_shape = true;
|
|
|
|
|
auto var = block_desc->Var(grad_var_name);
|
|
|
|
|
// FIXME(qiao) infer the datatype
|
|
|
|
|
var->SetDataType(framework::DataType::FP32);
|
|
|
|
|
new_vars.insert(var->Name());
|
|
|
|
|
auto it = param_name_map.find(grad_var_name);
|
|
|
|
|
if (it == param_name_map.end()) {
|
|
|
|
|
return false;
|
|
|
|
@ -316,6 +326,21 @@ static void CreateGradVarInBlock(
|
|
|
|
|
});
|
|
|
|
|
if (need_infer_shape) {
|
|
|
|
|
ops[op_index]->InferVarType(block_desc);
|
|
|
|
|
for (auto& arg : ops[op_index]->OutputArgumentNames()) {
|
|
|
|
|
if (new_vars.find(arg) == new_vars.end()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto pname = FwdName(arg);
|
|
|
|
|
auto* param = block_desc->FindVar(pname);
|
|
|
|
|
auto* grad = block_desc->FindVar(arg);
|
|
|
|
|
if (param == nullptr) {
|
|
|
|
|
LOG(WARNING) << "Cannot find forward variable of " << arg
|
|
|
|
|
<< ". Set its gradient to FP32";
|
|
|
|
|
grad->SetDataType(DataType::FP32);
|
|
|
|
|
} else {
|
|
|
|
|
grad->SetDataType(param->GetDataType());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
ops[op_index]->InferShape(*block_desc);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|