|
|
|
@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
|
|
|
|
|
"Condition of while op must in CPU memory.");
|
|
|
|
|
|
|
|
|
|
auto ctx = executor.Prepare(*program, block->ID());
|
|
|
|
|
while (cond.data<bool>()[0]) {
|
|
|
|
|
auto ¤t_scope = scope.NewScope();
|
|
|
|
|
step_scopes->push_back(¤t_scope);
|
|
|
|
|
|
|
|
|
|
executor.Run(*program, ¤t_scope, block->ID(),
|
|
|
|
|
false /*create_local_scope*/);
|
|
|
|
|
executor.RunPreparedContext(ctx.get(), ¤t_scope, false);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
framework::Executor executor(dev_place);
|
|
|
|
|
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
|
|
|
|
|
auto *program = block->Program();
|
|
|
|
|
auto ctx = executor.Prepare(*program, block->ID());
|
|
|
|
|
|
|
|
|
|
auto *step_scopes =
|
|
|
|
|
scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
|
|
|
|
@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
executor.Run(*program, *cur_scope_iter, block->ID(), false);
|
|
|
|
|
executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false);
|
|
|
|
|
|
|
|
|
|
auto &pg_names = Outputs(kXGRAD);
|
|
|
|
|
auto &p_names = Inputs(kX);
|
|
|
|
|