|
|
|
@ -270,6 +270,19 @@ static bool AllGradInSet(const std::vector<std::string>& names,
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (VLOG_IS_ON(10)) {
|
|
|
|
|
std::ostringstream sout;
|
|
|
|
|
sout << "All input {";
|
|
|
|
|
for (auto& name : names) {
|
|
|
|
|
sout << name << ",";
|
|
|
|
|
}
|
|
|
|
|
sout << "} is in {";
|
|
|
|
|
for (auto& name : set) {
|
|
|
|
|
sout << name << ",";
|
|
|
|
|
}
|
|
|
|
|
sout << "}";
|
|
|
|
|
VLOG(10) << sout.str();
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -290,14 +303,12 @@ static void CreateGradVarInBlock(
|
|
|
|
|
auto ops = block_desc->AllOps();
|
|
|
|
|
for (size_t op_index = grad_op_start_index; op_index < ops.size();
|
|
|
|
|
++op_index) {
|
|
|
|
|
bool need_infer_shape = false;
|
|
|
|
|
std::unordered_set<std::string> new_vars;
|
|
|
|
|
ForEachVarName(ops[op_index]->Outputs(),
|
|
|
|
|
[&](const std::string& grad_var_name) {
|
|
|
|
|
if (block_desc->HasVar(grad_var_name)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
need_infer_shape = true;
|
|
|
|
|
auto var = block_desc->Var(grad_var_name);
|
|
|
|
|
new_vars.insert(var->Name());
|
|
|
|
|
auto it = param_name_map.find(grad_var_name);
|
|
|
|
@ -311,23 +322,21 @@ static void CreateGradVarInBlock(
|
|
|
|
|
grad_record.op_idx_ = static_cast<int>(op_index);
|
|
|
|
|
return false; /* not break */
|
|
|
|
|
});
|
|
|
|
|
if (need_infer_shape) {
|
|
|
|
|
ops[op_index]->InferVarType(block_desc);
|
|
|
|
|
for (auto& arg : ops[op_index]->OutputArgumentNames()) {
|
|
|
|
|
if (new_vars.find(arg) == new_vars.end()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto pname = FwdName(arg);
|
|
|
|
|
auto* param = block_desc->FindVarRecursive(pname);
|
|
|
|
|
auto* grad = block_desc->FindVar(arg);
|
|
|
|
|
if (param == nullptr) {
|
|
|
|
|
grad->SetDataType(DataType::FP32);
|
|
|
|
|
} else {
|
|
|
|
|
grad->SetDataType(param->GetDataType());
|
|
|
|
|
}
|
|
|
|
|
ops[op_index]->InferVarType(block_desc);
|
|
|
|
|
for (auto& arg : ops[op_index]->OutputArgumentNames()) {
|
|
|
|
|
if (new_vars.find(arg) == new_vars.end()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto pname = FwdName(arg);
|
|
|
|
|
auto* param = block_desc->FindVarRecursive(pname);
|
|
|
|
|
auto* grad = block_desc->FindVar(arg);
|
|
|
|
|
if (param == nullptr) {
|
|
|
|
|
grad->SetDataType(DataType::FP32);
|
|
|
|
|
} else {
|
|
|
|
|
grad->SetDataType(param->GetDataType());
|
|
|
|
|
}
|
|
|
|
|
ops[op_index]->InferShape(*block_desc);
|
|
|
|
|
}
|
|
|
|
|
ops[op_index]->InferShape(*block_desc);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -387,6 +396,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
ProgramDescBind& program_desc, int block_idx,
|
|
|
|
|
std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var) {
|
|
|
|
|
VLOG(5) << "MakeBlockBackward";
|
|
|
|
|
BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
|
|
|
|
|
std::vector<OpDescBind*> op_descs = cur_block->AllOps();
|
|
|
|
|
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
|
|
|
|
@ -394,9 +404,10 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> backward_descs;
|
|
|
|
|
|
|
|
|
|
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
|
|
|
|
|
VLOG(5) << "Making backward " << (*it)->Type() << " op";
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> op_grads;
|
|
|
|
|
|
|
|
|
|
if ((*it)->Type() == "recurrent") {
|
|
|
|
|
if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
|
|
|
|
|
int step_block_idx = (*it)->GetBlockAttr("step_block");
|
|
|
|
|
BlockDescBind* backward_block = CreateStepBlock(
|
|
|
|
|
program_desc, no_grad_vars, grad_to_var, step_block_idx);
|
|
|
|
@ -410,6 +421,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (VLOG_IS_ON(10)) {
|
|
|
|
|
std::ostringstream sout;
|
|
|
|
|
sout << "Made ";
|
|
|
|
|
for (auto& op_grad : op_grads) {
|
|
|
|
|
sout << op_grad->Type() << " ";
|
|
|
|
|
}
|
|
|
|
|
VLOG(10) << sout.str();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto& desc : op_grads) {
|
|
|
|
|
for (const std::string& out_name : desc->OutputArgumentNames()) {
|
|
|
|
|
if (out_name.find("@GRAD") == std::string::npos) {
|
|
|
|
@ -425,6 +445,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
|
|
|
|
|
[](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(5) << "Appending Sums";
|
|
|
|
|
// Check whether some variables are written more than once
|
|
|
|
|
std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
|
|
|
|
|
for (const auto& dup : dup_out_ops) {
|
|
|
|
@ -432,16 +454,22 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
const std::vector<size_t> dup_op = dup.second;
|
|
|
|
|
if (out_name != kEmptyVarName && dup_op.size() > 1) {
|
|
|
|
|
std::vector<std::string> sum_op_inputs;
|
|
|
|
|
std::string next_g_name = out_name;
|
|
|
|
|
for (size_t i = 0; i < dup_op.size(); ++i) {
|
|
|
|
|
VLOG(10) << backward_descs[dup_op[i]]->Type() << " has " << out_name
|
|
|
|
|
<< " duplicated";
|
|
|
|
|
std::string new_name = out_name + "@RENAME@" + std::to_string(i);
|
|
|
|
|
backward_descs[dup_op[i]]->Rename(out_name, new_name);
|
|
|
|
|
backward_descs[dup_op[i]]->RenameOutput(out_name, new_name);
|
|
|
|
|
backward_descs[dup_op[i]]->RenameInput(out_name, next_g_name);
|
|
|
|
|
sum_op_inputs.emplace_back(new_name);
|
|
|
|
|
next_g_name = sum_op_inputs.back();
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
|
|
|
|
|
"sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
|
|
|
|
|
pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pending_sum_ops.sort(
|
|
|
|
|
[](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
|
|
|
|
|
const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
|
|
|
|
@ -452,6 +480,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
|
std::move(p.second));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(5) << "MakeBlockBackward Finished";
|
|
|
|
|
|
|
|
|
|
return backward_descs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|