|
|
|
@ -224,10 +224,12 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
if (cur_scope_iter == step_scopes->rbegin()) {
|
|
|
|
|
auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name);
|
|
|
|
|
PADDLE_ENFORCE(var->IsType<framework::LoDTensorArray>() ||
|
|
|
|
|
var->IsType<LoDTensor>(),
|
|
|
|
|
"Currently the type of var only can be LoDTensorArray "
|
|
|
|
|
"or LoDTensor.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
var->IsType<framework::LoDTensorArray>() ||
|
|
|
|
|
var->IsType<LoDTensor>(),
|
|
|
|
|
"Currently the type of var only can be LoDTensorArray, "
|
|
|
|
|
"or LoDTensor, but the received var[%s] is %s.",
|
|
|
|
|
inside_grad_name, var->Type().name());
|
|
|
|
|
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
auto &inside_tensor = var->Get<framework::LoDTensor>();
|
|
|
|
|