@@ -129,9 +129,6 @@ class WhileGradOp : public framework::OperatorBase {
       auto &og_inside =
           detail::Ref(cur_scope.Var(inside_og_name),
                       "Cannot find inside gradient %s", inside_og_name);
-
-      VLOG(10) << "OG " << outside_og_name << " Type is "
-               << og_outside.Type().name();
       if (og_outside.Type().hash_code() ==
           typeid(framework::LoDTensor).hash_code()) {
         auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
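This hunk removes a leftover debug trace: once the outside gradient has been linked into the loop scope, logging its runtime type on every step is pure noise. As a side note (my observation, not part of this patch), the `hash_code()` comparison kept below is the manual spelling of `framework::Variable::IsType<T>()`, which performs the same `typeid` check internally. A sketch, assuming the surrounding `og_outside` reference:

    // Equivalent type dispatch via Variable::IsType<T>(); an assumed
    // simplification, not code from this patch.
    if (og_outside.IsType<framework::LoDTensor>()) {
      // Dense case: alias the outside gradient buffer into cur_scope.
    } else if (og_outside.IsType<framework::LoDTensorArray>()) {
      // Array case: alias element-wise, handled in the next hunk.
    }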
@@ -148,6 +145,7 @@ class WhileGradOp : public framework::OperatorBase {
         inside_array.resize(outside_array.size());
 
         for (size_t j = 0; j < inside_array.size(); ++j) {
+          VLOG(10) << j << " " << outside_array[j].numel();
           if (outside_array[j].numel() != 0) {
             inside_array[j].set_lod(outside_array[j].lod());
             inside_array[j].ShareDataWith(outside_array[j]);
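The added `VLOG(10)` records each array element's `numel()` before the share, making it visible which slots of the gradient array are actually populated; empty elements are skipped so uninitialized slots are never aliased. For context (an assumed simplification, not from this patch), `ShareDataWith` aliases the source tensor's allocation rather than copying it, while `set_lod` copies only the LoD metadata:

    // Sketch of the sharing semantics, assuming CPU place; after the call,
    // both tensors refer to the same underlying allocation.
    framework::LoDTensor src, dst;
    src.mutable_data<float>(framework::make_ddim({2, 3}),
                            platform::CPUPlace());
    dst.set_lod(src.lod());   // LoD metadata is copied
    dst.ShareDataWith(src);   // data buffer is aliased, not copied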
@@ -200,17 +198,6 @@ class WhileGradOp : public framework::OperatorBase {
       auto sum_op = framework::OpRegistry::CreateOp(
           "sum", {{"X", {pg_names[param_id], new_inside_name}}},
           {{"Out", {pg_names[param_id]}}}, framework::AttributeMap{});
-
-      VLOG(10) << "Accumulate the gradient of " << pg_names[param_id];
-
-      if (pg_names[param_id] == "W@GRAD") {
-        auto &w_g = detail::Ref(cur_scope.FindVar(new_inside_name))
-                        .Get<framework::LoDTensor>();
-        VLOG(10) << "W_G is" << w_g.data<float>()[0];
-      } else {
-        VLOG(10) << pg_names[param_id];
-      }
-
       sum_op->Run(cur_scope, dev_place);
       cur_scope.Rename(new_inside_name, inside_grad_name);
     }
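The removed block was model-specific debugging: it special-cased a parameter gradient literally named `W@GRAD` and printed its first float value, which only makes sense for one particular test network. The accumulation logic that remains builds a `sum` operator on the fly with the running gradient as both input and output, i.e. it computes `pg += new_inside` by reusing the registered kernel, and then `cur_scope.Rename(new_inside_name, inside_grad_name)` restores the temporary's original name for the next iteration. A minimal sketch of that pattern with hypothetical variable names (placeholders, not from this patch):

    // Accumulate a step-local gradient into a running sum in place by
    // running the registered "sum" kernel; the names here are placeholders.
    auto sum_op = framework::OpRegistry::CreateOp(
        "sum", {{"X", {"p@GRAD", "p@GRAD@step"}}},
        {{"Out", {"p@GRAD"}}}, framework::AttributeMap{});
    sum_op->Run(cur_scope, dev_place);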