@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <set>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
@@ -70,6 +72,23 @@ class WhileOp : public framework::OperatorBase {
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
bool is_test = Attr<bool>("is_test");
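// Record variables that any op in the step block uses as both an input and
// an output: those are updated in place, so the copy-in logic in the loop
// below must leave them alone.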
std::set<std::string> no_copy_var_names;
if (!is_test) {
const std::vector<framework::OpDesc *> &all_ops = block->AllOps();
for (const framework::OpDesc *op : all_ops) {
const framework::VariableNameMap &input_var_names = op->Inputs();
const framework::VariableNameMap &output_var_names = op->Outputs();
for (auto &ipt : input_var_names) {
for (const std::string &var_name : ipt.second) {
if (StrInVaraiableNameMap(var_name, output_var_names)) {
no_copy_var_names.insert(var_name);
}
}
}
}
}
auto step_scopes =
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
@@ -89,7 +108,6 @@ class WhileOp : public framework::OperatorBase {
"The Output(StepScope) of WhileOp should be empty."));
bool cond_data = GetCondData(cond);
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
@@ -98,8 +116,32 @@ class WhileOp : public framework::OperatorBase {
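// Each iteration below runs the step block in a fresh scope; the scopes are
// collected in step_scopes so that the gradient op can replay the steps in
// reverse.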
while (cond_data) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
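// Inputs that are not updated in place are deep-copied into the step scope
// under the name plus kSuffix, so every iteration starts from its own
// snapshot of the outer value.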
std::vector<std::string> rename_vars;
for (const std::string &input_var_name : Inputs(kX)) {
if (no_copy_var_names.find(input_var_name) ==
no_copy_var_names.end()) {
std::string input_var_rename = input_var_name + kSuffix;
framework::Variable *input_var = scope.FindVar(input_var_name);
if (input_var->IsType<framework::LoDTensor>()) {
rename_vars.push_back(input_var_rename);
auto input_var_tensor = input_var->Get<LoDTensor>();
auto *rename_input_var_tensor =
current_scope.Var(input_var_rename)->GetMutable<LoDTensor>();
framework::TensorCopy(input_var_tensor, dev_place,
rename_input_var_tensor);
rename_input_var_tensor->set_lod(input_var_tensor.lod());
}
}
}
executor.RunPreparedContext(ctx.get(), &current_scope, false, true,
true);
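// Now that the step block has run, strip kSuffix from the copies so that
// this step scope exposes them under their original names.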
for (auto &var_rename : rename_vars) {
std::string input_var_name =
var_rename.substr(0, var_rename.size() - strlen(kSuffix));
current_scope.Rename(var_rename, input_var_name);
}
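// Re-read the condition computed by the step block to decide whether to
// run another iteration.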
cond_data =
GetCondData(scope.FindVar(Input(kCondition))->Get<LoDTensor>());
}
@@ -312,6 +354,10 @@ class WhileGradOp : public framework::OperatorBase {
// continue;
// }
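// Look up whether this parameter gradient is also one of the outer output
// gradients; the zero-fill and the accumulation below are gated on the
// result.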
auto var_iter =
std::find(outside_og_names.begin(), outside_og_names.end(),
pg_ig_names[param_id]);
// zero gradient variable in step 0
if (cur_scope_iter == step_scopes->rbegin()) {
auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
@@ -326,7 +372,8 @@ class WhileGradOp : public framework::OperatorBase {
"or LoDTensor, but the received var[%s] is %s.",
inside_grad_name, framework::ToTypeName(var->Type())));
if ((var_iter == outside_og_names.end()) &&
var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["dtype"] = inside_tensor.type();
@@ -343,13 +390,18 @@ class WhileGradOp : public framework::OperatorBase {
->set_lod(inside_tensor.lod());
}
}
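// Accumulate this step's gradient: temporarily rename the inside gradient,
// add it onto the running total with a "sum" op, then restore the original
// name for the next (earlier) step.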
auto var_outside = scope.FindVar(pg_ig_names[param_id]);
if ((var_iter == outside_og_names.end()) ||
((var_iter != outside_og_names.end()) &&
var_outside->IsType<framework::LoDTensorArray>())) {
auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
{{"Out", {pg_ig_names[param_id]}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
}
}
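// Make sure all queued device work has finished before the step scope and
// the tensors it owns are destroyed.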
dev_ctx.Wait();
const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);