|
|
|
@ -194,14 +194,27 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto check_var_no_nan = [](const framework::Scope &scope,
|
|
|
|
|
const std::string &var_name) {
|
|
|
|
|
auto *var = scope.FindVar(var_name);
|
|
|
|
|
if (var->IsType<LoDTensor>()) {
|
|
|
|
|
VLOG(10) << "Checking " << var_name;
|
|
|
|
|
PADDLE_ENFORCE(!framework::HasNAN(var->Get<framework::LoDTensor>()),
|
|
|
|
|
"%s has NAN", var_name);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
check_var_no_nan(cur_scope, inside_grad_name);
|
|
|
|
|
auto new_inside_name = cur_scope.Rename(inside_grad_name);
|
|
|
|
|
check_var_no_nan(cur_scope, new_inside_name);
|
|
|
|
|
auto sum_op = framework::OpRegistry::CreateOp(
|
|
|
|
|
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
|
|
|
|
|
{{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
|
|
|
|
|
sum_op->Run(cur_scope, dev_place);
|
|
|
|
|
check_var_no_nan(cur_scope, pg_names[param_id]);
|
|
|
|
|
cur_scope.Rename(new_inside_name, inside_grad_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
VLOG(1) << "Complete WhileOpGrad";
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|