|
|
|
@ -147,6 +147,18 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
|
|
|
|
|
vars.emplace(node->Var()->Name(), node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Set Gradients as Persistable to prevent this var becoming reusable.
|
|
|
|
|
for (auto &grad_var_name : grads) {
|
|
|
|
|
auto iter = vars.find(grad_var_name);
|
|
|
|
|
PADDLE_ENFORCE(iter != vars.end());
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
|
|
|
|
|
PADDLE_ENFORCE(iter->second->Var()->GetType() == proto::VarType::LOD_TENSOR,
|
|
|
|
|
"Currently the gradient type only should be LoDTensor when "
|
|
|
|
|
"fusing optimizer ops.");
|
|
|
|
|
iter->second->Var()->SetPersistable(true);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Init Grads
|
|
|
|
|
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
|
|
|
|
|
auto &scope = *it;
|
|
|
|
@ -154,13 +166,10 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
|
|
|
|
|
PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr,
|
|
|
|
|
"%s has existed in scope.", fused_grad_name);
|
|
|
|
|
scope->Var(fused_grad_name)->GetMutable<LoDTensor>();
|
|
|
|
|
|
|
|
|
|
for (auto &grad_var_name : grads) {
|
|
|
|
|
auto iter = vars.find(grad_var_name);
|
|
|
|
|
PADDLE_ENFORCE(iter != vars.end());
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
|
|
|
|
|
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
|
|
|
|
|
proto::VarType::LOD_TENSOR);
|
|
|
|
|
scope->Var(grad_var_name)->GetMutable<LoDTensor>();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|