@@ -387,8 +387,8 @@ class RecurrentGradOp : public RecurrentBase {
       auto &p_names = Inputs(kParameters);
       PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
 
-      for (size_t prog_id = 0; prog_id < pg_names.size(); ++prog_id) {
-        auto inside_grad_name = framework::GradVarName(p_names[prog_id]);
+      for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
+        auto inside_grad_name = framework::GradVarName(p_names[param_id]);
 
         // If does not compute gradient of that variable inside rnn, just
         // continue
@@ -406,27 +406,19 @@ class RecurrentGradOp : public RecurrentBase {
           attrs["value"] = 0.0f;
 
           auto zero_op = framework::OpRegistry::CreateOp(
-              "fill_constant", {}, {{"Out", {pg_names[prog_id]}}}, attrs);
+              "fill_constant", {}, {{"Out", {pg_names[param_id]}}}, attrs);
           zero_op->Run(scope, dev_ctx);
         }
 
+        auto new_inside_name = cur_scope.Rename(inside_grad_name);
         // sum gradient
-        auto *outside_var = scope.FindVar(pg_names[prog_id]);
-        PADDLE_ENFORCE(outside_var != nullptr);
-        auto &outside_tensor =
-            *outside_var->GetMutable<framework::LoDTensor>();
-
-        std::string result_var_name;
-        auto *local_result_var = cur_scope.Var(&result_var_name);
-        auto &local_result_tensor =
-            *local_result_var->GetMutable<framework::LoDTensor>();
-
-        local_result_tensor.ShareDataWith(outside_tensor);
-
+
         auto sum_op = framework::OpRegistry::CreateOp(
-            "sum", {{"X", {result_var_name, inside_grad_name}}},
-            {{"Out", {result_var_name}}}, {});
+            "sum", {{"X", {pg_names[param_id], new_inside_name}}},
+            {{"Out", {pg_names[param_id]}}}, {});
         sum_op->Run(cur_scope, dev_ctx);
+
+        cur_scope.Rename(new_inside_name, inside_grad_name);
       }
     }
     VLOG(5) << "Accumulate Parameter finished ";
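A minimal standalone sketch of the accumulate-by-rename pattern the new hunk adopts, using a toy `Scope` (a plain `string -> float` map holding scalar "tensors"; `Rename`, `W@GRAD`, and the `@RENAMED` suffix are hypothetical stand-ins, not Paddle APIs). The idea: rather than materializing a local result variable that shares data with the outer gradient tensor, each step renames the inside gradient out of the way, lets `sum` write `outer_grad + step_grad` directly into the parameter-gradient name, then renames back so the next step finds the expected variable.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Toy scope: variable name -> scalar "tensor". Illustrative only.
using Scope = std::unordered_map<std::string, float>;

// Rough analogue of cur_scope.Rename(name): move the value under a fresh
// name and return it, so `name` no longer shadows the outer variable.
std::string Rename(Scope &scope, const std::string &name) {
  std::string renamed = name + "@RENAMED";  // hypothetical suffix
  scope[renamed] = scope.at(name);
  scope.erase(name);
  return renamed;
}

int main() {
  Scope outer;             // holds the accumulated parameter gradient
  outer["W@GRAD"] = 0.0f;  // zero-init, like the fill_constant op above

  Scope step;              // per-timestep scope
  const float step_grads[] = {1.0f, 2.0f, 3.0f};
  for (float g : step_grads) {
    step["W@GRAD"] = g;    // gradient produced inside this timestep

    // 1) Rename the inside gradient out of the way.
    std::string renamed = Rename(step, "W@GRAD");
    // 2) "sum": outer grad + renamed step grad -> outer grad, in place.
    outer["W@GRAD"] += step[renamed];
    // 3) Rename back so the next step sees the expected name.
    step["W@GRAD"] = step[renamed];
    step.erase(renamed);
  }

  assert(outer["W@GRAD"] == 6.0f);  // 1 + 2 + 3
  return 0;
}
```

Compared to the removed code path, this avoids the `ShareDataWith` temporary entirely: `sum`'s output is the parameter gradient itself, so no intermediate `result_var_name` has to be created and kept aliased to the outside tensor.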