fix fuse optimizer ops (#17102)

test=develop
feature/fluid_trt_int8
chengduo 6 years ago committed by GitHub
parent 258e000be6
commit 794a195881

@@ -147,6 +147,18 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
vars.emplace(node->Var()->Name(), node);
}
}
// Set gradients as persistable to prevent these vars from becoming reusable.
for (auto &grad_var_name : grads) {
auto iter = vars.find(grad_var_name);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE(iter->second->Var()->GetType() == proto::VarType::LOD_TENSOR,
"Currently the gradient type only should be LoDTensor when "
"fusing optimizer ops.");
iter->second->Var()->SetPersistable(true);
}
// Init Grads
for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
auto &scope = *it;
@@ -154,13 +166,10 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr,
"%s has existed in scope.", fused_grad_name);
scope->Var(fused_grad_name)->GetMutable<LoDTensor>();
for (auto &grad_var_name : grads) {
auto iter = vars.find(grad_var_name);
PADDLE_ENFORCE(iter != vars.end());
PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
proto::VarType::LOD_TENSOR);
scope->Var(grad_var_name)->GetMutable<LoDTensor>();
}
}
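The two hunks above implement a single idea: before the pass allocates the fused gradient buffer, every individual gradient variable is checked to be a LoDTensor and marked persistable so that later memory-reuse passes cannot recycle its storage, and then the fused gradient plus each individual gradient are created in every local scope. The following stand-alone C++ sketch mirrors that control flow under simplified assumptions; ToyVarDesc, ToyScope, the variable names, and the fused gradient name are hypothetical stand-ins introduced only for illustration and are not Paddle's real VarDesc/Scope API.

    // Minimal, self-contained sketch of the control flow in this change.
    // ToyVarDesc and ToyScope are hypothetical stand-ins, not Paddle types.
    #include <cassert>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    enum class VarType { LOD_TENSOR, SELECTED_ROWS };

    struct ToyVarDesc {
      VarType type = VarType::LOD_TENSOR;
      bool persistable = false;
    };

    struct ToyScope {
      std::unordered_map<std::string, int> vars;
      // Creates the variable if it does not exist yet, mirroring Scope::Var().
      void Var(const std::string &name) { vars.emplace(name, 0); }
      bool Has(const std::string &name) const { return vars.count(name) > 0; }
    };

    int main() {
      std::unordered_map<std::string, ToyVarDesc> vars = {
          {"w@GRAD", {}}, {"b@GRAD", {}}};
      std::vector<std::string> grads = {"w@GRAD", "b@GRAD"};
      std::vector<ToyScope> local_scopes(2);
      const std::string fused_grad_name = "@FUSED_GRAD@";

      // Step 1: require each gradient to be a LoDTensor and mark it
      // persistable so a later memory-reuse pass will not recycle its buffer.
      for (const auto &grad_var_name : grads) {
        auto iter = vars.find(grad_var_name);
        assert(iter != vars.end());
        assert(iter->second.type == VarType::LOD_TENSOR &&
               "Currently the gradient type should only be LoDTensor when "
               "fusing optimizer ops.");
        iter->second.persistable = true;
      }

      // Step 2: in every local scope create the fused gradient variable and
      // each individual gradient variable (in the real pass these are
      // LoDTensors obtained via GetMutable<LoDTensor>()).
      for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
        assert(!it->Has(fused_grad_name) && "fused grad must not exist yet");
        it->Var(fused_grad_name);
        for (const auto &grad_var_name : grads) {
          it->Var(grad_var_name);
        }
      }
      std::cout << "all gradients marked persistable and allocated\n";
      return 0;
    }

Note how the sketch keeps the same ordering as the diff: type checking and the persistable flag come first, and per-scope allocation of the fused and individual gradients comes second, so the gradients are guaranteed to survive any variable-reuse optimization applied afterwards.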
