Optimize the implementation of while_op again, for cases when is_test is true. (#16359)

test=develop
move-code
Yiqun Liu 6 years ago committed by GitHub
parent c34b24ede7
commit 98802e1f75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -51,6 +51,7 @@ class WhileOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
@ -70,13 +71,34 @@ class WhileOp : public framework::OperatorBase {
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
while (cond.data<bool>()[0]) {
if (!is_test) {
while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false, true,
true);
}
} else {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
if (is_test) {
scope.DeleteScope(&current_scope);
executor.CreateVariables(*program, &current_scope, block->ID());
while (cond.data<bool>()[0]) {
for (auto &name : current_scope.LocalVarNames()) {
auto *var = current_scope.Var(name);
if (var->IsType<framework::LoDTensor>()) {
// Clear all lod information for all lod_tensors.
auto *t = var->GetMutable<framework::LoDTensor>();
framework::LoD empty_lod;
t->set_lod(empty_lod);
} else if (var->IsType<framework::LoDTensorArray>()) {
// Clear elements of all tensor arrays.
auto *t = var->GetMutable<framework::LoDTensorArray>();
t->clear();
}
}
executor.RunPreparedContext(ctx.get(), &current_scope, false, false,
false);
}
scope.DeleteScope(&current_scope);
}
}
};

Loading…
Cancel
Save