|
|
|
@ -58,7 +58,6 @@ class WhileOp : public framework::OperatorBase {
|
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
|
const platform::Place &dev_place) const override {
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
|
|
|
|
|
|
|
|
|
|
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
|
|
|
|
|
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
|
|
|
|
|
|
|
|
|
@ -78,35 +77,15 @@ class WhileOp : public framework::OperatorBase {
|
|
|
|
|
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
|
|
|
|
|
|
|
|
|
|
auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
|
|
|
|
|
if (!is_test) {
|
|
|
|
|
while (cond.data<bool>()[0]) {
|
|
|
|
|
auto ¤t_scope = scope.NewScope();
|
|
|
|
|
step_scopes->push_back(¤t_scope);
|
|
|
|
|
executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true,
|
|
|
|
|
true);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto ¤t_scope = scope.NewScope();
|
|
|
|
|
executor.CreateVariables(*program, ¤t_scope, block->ID());
|
|
|
|
|
while (cond.data<bool>()[0]) {
|
|
|
|
|
for (auto &name : current_scope.LocalVarNames()) {
|
|
|
|
|
auto *var = current_scope.Var(name);
|
|
|
|
|
framework::LoD empty_lod;
|
|
|
|
|
if (var->IsType<framework::LoDTensor>()) {
|
|
|
|
|
// Clear all lod information for all lod_tensors.
|
|
|
|
|
auto *t = var->GetMutable<framework::LoDTensor>();
|
|
|
|
|
t->set_lod(empty_lod);
|
|
|
|
|
} else if (var->IsType<framework::LoDRankTable>()) {
|
|
|
|
|
auto *t = var->GetMutable<framework::LoDRankTable>();
|
|
|
|
|
t->Reset(empty_lod, 0);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
executor.RunPreparedContext(ctx.get(), ¤t_scope, false, false,
|
|
|
|
|
false);
|
|
|
|
|
}
|
|
|
|
|
executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, true);
|
|
|
|
|
if (is_test) {
|
|
|
|
|
scope.DeleteScope(¤t_scope);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|