|
|
|
@ -57,8 +57,10 @@ class ConditionalBlockOp : public ConditionalOp {
|
|
|
|
|
if (need_run) {
|
|
|
|
|
auto *scope_var = scope.FindVar(Output(ConditionalOp::kScope));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
scope_var, platform::errors::PreconditionNotMet(
|
|
|
|
|
"Scope must be set in conditional_block_op."));
|
|
|
|
|
scope_var,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Expect Scope variable to be set in conditional_block_op, but "
|
|
|
|
|
"got a null Scope variable. Please set the Scope variable."));
|
|
|
|
|
auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
|
|
|
|
|
scopes->resize(1);
|
|
|
|
|
scopes->front() = &scope.NewScope();
|
|
|
|
@ -119,12 +121,16 @@ class ConditionalBlockGradOp : public ConditionalOp {
|
|
|
|
|
|
|
|
|
|
auto *scope_var = scope.FindVar(Input(ConditionalOp::kScope));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
scope_var, platform::errors::PreconditionNotMet(
|
|
|
|
|
"Scope must be set in conditional block op."));
|
|
|
|
|
scope_var,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Expect Scope variable to be set in conditional_block_op, but "
|
|
|
|
|
"got a null Scope variable. Please set the Scope variable."));
|
|
|
|
|
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
|
|
|
|
|
PADDLE_ENFORCE_GT(scopes.size(), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Scope must be set in conditional block op."));
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
scopes.size(), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Expect Scope variable contains at least 1 scope, but got: %d",
|
|
|
|
|
scopes.size()));
|
|
|
|
|
framework::Scope &cur_scope = *scopes[0];
|
|
|
|
|
|
|
|
|
|
framework::Executor exec(dev_place);
|
|
|
|
|