|
|
|
@ -62,7 +62,7 @@ class WhileOp : public framework::OperatorBase {
|
|
|
|
|
|
|
|
|
|
auto step_scopes =
|
|
|
|
|
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(step_scopes->size(), 0, "The StepScope should be empty.");
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
|
|
|
|
|
"Condition of while op must in CPU memory.");
|
|
|
|
|
|
|
|
|
@ -197,17 +197,22 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
inside_tensor.set_lod(outside_tensor.lod());
|
|
|
|
|
inside_tensor.ShareDataWith(outside_tensor);
|
|
|
|
|
} else if (og_outside.IsType<framework::LoDTensorArray>()) {
|
|
|
|
|
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
|
|
|
|
|
auto outside_array =
|
|
|
|
|
og_outside.GetMutable<framework::LoDTensorArray>();
|
|
|
|
|
auto &inside_array =
|
|
|
|
|
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
|
|
|
|
|
VLOG(8) << outside_og_name << " size = " << outside_array.size();
|
|
|
|
|
inside_array.resize(outside_array.size());
|
|
|
|
|
inside_array.clear();
|
|
|
|
|
inside_array.resize(outside_array->size());
|
|
|
|
|
VLOG(8) << outside_og_name << " size = " << outside_array->size();
|
|
|
|
|
|
|
|
|
|
for (size_t j = 0; j < inside_array.size(); ++j) {
|
|
|
|
|
VLOG(8) << j << " " << outside_array[j].numel();
|
|
|
|
|
if (outside_array[j].numel() != 0) {
|
|
|
|
|
inside_array[j].set_lod(outside_array[j].lod());
|
|
|
|
|
inside_array[j].ShareDataWith(outside_array[j]);
|
|
|
|
|
if (!outside_array->at(j).IsInitialized()) {
|
|
|
|
|
outside_array->at(j).Resize({0});
|
|
|
|
|
}
|
|
|
|
|
VLOG(8) << j << " " << outside_array->at(j).numel();
|
|
|
|
|
if (outside_array->at(j).numel() != 0) {
|
|
|
|
|
inside_array[j].set_lod(outside_array->at(j).lod());
|
|
|
|
|
inside_array[j].ShareDataWith(outside_array->at(j));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
|
|
|
|
|
}
|
|
|
|
@ -300,6 +305,7 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
dev_ctx.Wait();
|
|
|
|
|
const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
|
|
|
|
|
}
|
|
|
|
|
step_scopes->clear();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|