|
|
|
@ -56,7 +56,9 @@ class ConditionalBlockOp : public ConditionalOp {
|
|
|
|
|
|
|
|
|
|
if (need_run) {
|
|
|
|
|
auto *scope_var = scope.FindVar(Output(ConditionalOp::kScope));
|
|
|
|
|
PADDLE_ENFORCE(scope_var != nullptr, "Must set scope");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
scope_var, platform::errors::PreconditionNotMet(
|
|
|
|
|
"Scope must be set in conditional_block_op."));
|
|
|
|
|
auto *scopes = scope_var->GetMutable<std::vector<framework::Scope *>>();
|
|
|
|
|
scopes->resize(1);
|
|
|
|
|
scopes->front() = &scope.NewScope();
|
|
|
|
@ -79,7 +81,7 @@ class ConditionalBlockInferShape : public framework::InferShapeBase {
|
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(context->HasInputs(ConditionalOp::kCondition), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"conditional_block_op must have condition input"));
|
|
|
|
|
"conditional_block_op must have condition input."));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -116,13 +118,13 @@ class ConditionalBlockGradOp : public ConditionalOp {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto *scope_var = scope.FindVar(Input(ConditionalOp::kScope));
|
|
|
|
|
PADDLE_ENFORCE_NE(scope_var, nullptr,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Scope must be set in conditional block op"));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
scope_var, platform::errors::PreconditionNotMet(
|
|
|
|
|
"Scope must be set in conditional block op."));
|
|
|
|
|
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
|
|
|
|
|
PADDLE_ENFORCE_GT(scopes.size(), 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Scope must be set in conditional block op"));
|
|
|
|
|
"Scope must be set in conditional block op."));
|
|
|
|
|
framework::Scope &cur_scope = *scopes[0];
|
|
|
|
|
|
|
|
|
|
framework::Executor exec(dev_place);
|
|
|
|
@ -192,7 +194,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
|
|
|
|
|
PADDLE_ENFORCE_EQ(outside_var->IsType<framework::LoDTensor>(), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Type of outside_var %s is NOT LoDTensor, which "
|
|
|
|
|
"doesn't match input_var %s",
|
|
|
|
|
"doesn't match input_var %s.",
|
|
|
|
|
outside_grad_name, input_name));
|
|
|
|
|
AssignZeroToOutsideTensor(
|
|
|
|
|
place, scope, input_var->Get<framework::LoDTensor>(),
|
|
|
|
@ -202,7 +204,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
|
|
|
|
|
true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Type of outside_var %s is NOT LoDTensorArray, "
|
|
|
|
|
"which doesn't match input_var %s",
|
|
|
|
|
"which doesn't match input_var %s.",
|
|
|
|
|
outside_grad_name, input_name));
|
|
|
|
|
const auto &input_tensors = input_var->Get<framework::LoDTensorArray>();
|
|
|
|
|
auto *outside_tensors =
|
|
|
|
@ -210,7 +212,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
|
|
|
|
|
PADDLE_ENFORCE_EQ(input_tensors.size(), outside_tensors->size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"LoDTensorArray outside_var %s doen't have same "
|
|
|
|
|
"size as input_var %s",
|
|
|
|
|
"size as input_var %s.",
|
|
|
|
|
outside_grad_name, input_name));
|
|
|
|
|
for (size_t j = 0; j < input_tensors.size(); ++j) {
|
|
|
|
|
AssignZeroToOutsideTensor(place, scope, input_tensors[j],
|
|
|
|
@ -220,7 +222,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
|
|
|
|
|
// TODO(huihuangzheng): add support for SelectedRows
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"Conditional block grad op doesn't support non-LoDTensor output "
|
|
|
|
|
"now"));
|
|
|
|
|
"now."));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -245,7 +247,10 @@ class ConditionalBlockGradOp : public ConditionalOp {
|
|
|
|
|
class ConditionalBlockGradInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
|
PADDLE_ENFORCE(context->HasInputs(ConditionalOp::kCondition));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
context->HasInputs(ConditionalOp::kCondition), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Condition must be set in conditional_block_grad_op."));
|
|
|
|
|
if (context->HasInputs(ConditionalOp::kInputs) &&
|
|
|
|
|
context->HasOutputs(framework::GradVarName(ConditionalOp::kInputs))) {
|
|
|
|
|
context->SetOutputsDim(framework::GradVarName(ConditionalOp::kInputs),
|
|
|
|
|