|
|
|
@ -65,7 +65,7 @@ class ConditionalBlockOp : public ConditionalOp {
|
|
|
|
|
scopes->front() = &scope.NewScope();
|
|
|
|
|
auto &cur_scope = *scopes->front();
|
|
|
|
|
|
|
|
|
|
auto *block = Attr<framework::BlockDescBind *>("block");
|
|
|
|
|
auto *block = Attr<framework::BlockDescBind *>("sub_block");
|
|
|
|
|
framework::Executor exec(dev_ctx);
|
|
|
|
|
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
|
|
|
|
|
}
|
|
|
|
@ -88,7 +88,7 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"unify the conditional block, rnn and while op, the type of "
|
|
|
|
|
"scope is std::vector<Scope*>");
|
|
|
|
|
AddAttr<framework::BlockDescBind *>(
|
|
|
|
|
"block", "The step block of conditional block operator");
|
|
|
|
|
"sub_block", "The step block of conditional block operator");
|
|
|
|
|
AddComment(R"DOC(Conditional block operator
|
|
|
|
|
|
|
|
|
|
Run the sub-block if X is not empty. Params is the other inputs and Out is the
|
|
|
|
@ -117,7 +117,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
|
|
|
|
|
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
|
|
|
|
|
framework::Scope &cur_scope = *scopes[0];
|
|
|
|
|
|
|
|
|
|
auto *block = Attr<framework::BlockDescBind *>("block");
|
|
|
|
|
auto *block = Attr<framework::BlockDescBind *>("sub_block");
|
|
|
|
|
framework::Executor exec(dev_ctx);
|
|
|
|
|
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
|
|
|
|
|
|
|
|
|
@ -181,7 +181,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
grad_op->SetInput("Scope", Output("Scope"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params"));
|
|
|
|
|
grad_op->SetBlockAttr("block", *this->grad_block_[0]);
|
|
|
|
|
grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
|
|
|
|
|
return std::unique_ptr<framework::OpDescBind>(grad_op);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|