|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/details/build_strategy.h"
|
|
|
|
@ -84,16 +85,19 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (params_grads.size() == 0) {
|
|
|
|
|
LOG(WARNING) << "Doesn't find gradients";
|
|
|
|
|
LOG(INFO) << "Doesn't find gradients";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, ir::Node *> vars;
|
|
|
|
|
std::unordered_map<std::string, ir::Node *> var_name2node;
|
|
|
|
|
std::unordered_map<std::string, std::unordered_set<ir::Node *>>
|
|
|
|
|
var_name2node_set;
|
|
|
|
|
for (ir::Node *node : result.Nodes()) {
|
|
|
|
|
if (node->IsVar() && node->Var()) {
|
|
|
|
|
// Note: The graph may have the same name node. For example, parameter
|
|
|
|
|
// is the input of operator and it also is the output of optimizer;
|
|
|
|
|
vars.emplace(node->Var()->Name(), node);
|
|
|
|
|
var_name2node.emplace(node->Var()->Name(), node);
|
|
|
|
|
var_name2node_set[node->Var()->Name()].emplace(node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -101,7 +105,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
|
|
|
|
|
result.Get<details::GroupGradsAndParams>(details::kGroupGradsAndParams);
|
|
|
|
|
|
|
|
|
|
// Note: the order of params_grads may be changed by SetGroupGradsAndParams.
|
|
|
|
|
SetGroupGradsAndParams(vars, params_grads, &group_grads_params);
|
|
|
|
|
SetGroupGradsAndParams(var_name2node, params_grads, &group_grads_params);
|
|
|
|
|
|
|
|
|
|
params_grads.clear();
|
|
|
|
|
for (auto &group_p_g : group_grads_params) {
|
|
|
|
@ -116,9 +120,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
|
|
|
|
|
auto dtype = kDefaultDtype;
|
|
|
|
|
for (auto &p_g : params_grads) {
|
|
|
|
|
// Get gradient var
|
|
|
|
|
auto iter = vars.find(p_g.second);
|
|
|
|
|
PADDLE_ENFORCE(iter != vars.end(), "%s is not found.", p_g.second);
|
|
|
|
|
iter->second->Var()->SetPersistable(true);
|
|
|
|
|
auto iter = var_name2node.find(p_g.second);
|
|
|
|
|
PADDLE_ENFORCE(iter != var_name2node.end(), "%s is not found.",
|
|
|
|
|
p_g.second);
|
|
|
|
|
// Set persistable
|
|
|
|
|
auto same_nodes = var_name2node_set.find(p_g.second);
|
|
|
|
|
PADDLE_ENFORCE(same_nodes != var_name2node_set.end(), "%s is not found.",
|
|
|
|
|
p_g.second);
|
|
|
|
|
for (auto it : same_nodes->second) {
|
|
|
|
|
it->Var()->SetPersistable(true);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(IsSupportedVarType(iter->second->Var()->GetType()));
|
|
|
|
|
|
|
|
|
@ -151,7 +162,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
|
|
|
|
|
"%s is duplicate in FusedVars.", fused_var_name);
|
|
|
|
|
fused_var_set.insert(fused_var_name);
|
|
|
|
|
|
|
|
|
|
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, vars,
|
|
|
|
|
InitFusedVarsAndAllocSpaceForVars(places, local_scopes, var_name2node,
|
|
|
|
|
fused_var_name, params_grads);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|