|
|
|
@ -121,8 +121,8 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
for (size_t i = 0; i < outside_og_names.size(); ++i) {
|
|
|
|
|
auto outside_og_name = outside_og_names[i];
|
|
|
|
|
auto inside_og_name = inside_og_names[i];
|
|
|
|
|
VLOG(10) << "Linking outside " << outside_og_name << " --> inside "
|
|
|
|
|
<< inside_og_name;
|
|
|
|
|
VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
|
|
|
|
|
<< inside_og_name;
|
|
|
|
|
auto &og_outside =
|
|
|
|
|
detail::Ref(scope.FindVar(outside_og_name),
|
|
|
|
|
"Cannot find Outside Gradient %s", outside_og_name);
|
|
|
|
@ -141,11 +141,11 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
|
|
|
|
|
auto &inside_array =
|
|
|
|
|
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
|
|
|
|
|
VLOG(10) << outside_og_name << " size = " << outside_array.size();
|
|
|
|
|
VLOG(8) << outside_og_name << " size = " << outside_array.size();
|
|
|
|
|
inside_array.resize(outside_array.size());
|
|
|
|
|
|
|
|
|
|
for (size_t j = 0; j < inside_array.size(); ++j) {
|
|
|
|
|
VLOG(10) << j << " " << outside_array[j].numel();
|
|
|
|
|
VLOG(8) << j << " " << outside_array[j].numel();
|
|
|
|
|
if (outside_array[j].numel() != 0) {
|
|
|
|
|
inside_array[j].set_lod(outside_array[j].lod());
|
|
|
|
|
inside_array[j].ShareDataWith(outside_array[j]);
|
|
|
|
@ -187,10 +187,14 @@ class WhileGradOp : public framework::OperatorBase {
|
|
|
|
|
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
|
|
|
|
|
attrs["value"] = 0.0f;
|
|
|
|
|
|
|
|
|
|
auto var_name = pg_names[param_id];
|
|
|
|
|
auto zero_op = framework::OpRegistry::CreateOp(
|
|
|
|
|
"fill_constant", framework::VariableNameMap{},
|
|
|
|
|
{{"Out", {pg_names[param_id]}}}, attrs);
|
|
|
|
|
{{"Out", {var_name}}}, attrs);
|
|
|
|
|
zero_op->Run(scope, dev_place);
|
|
|
|
|
scope.FindVar(var_name)
|
|
|
|
|
->GetMutable<framework::LoDTensor>()
|
|
|
|
|
->set_lod(inside_tensor.lod());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -231,7 +235,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
|
|
|
|
|
for (auto &each_ig : igs) {
|
|
|
|
|
if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
|
|
|
|
|
VLOG(10) << "Ignore " << each_ig;
|
|
|
|
|
VLOG(8) << "Ignore " << each_ig;
|
|
|
|
|
each_ig = framework::kEmptyVarName;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|