|
|
|
@ -206,24 +206,23 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
|
|
|
|
|
details::GroupParamsAndGrads *group_params_grads) const {
|
|
|
|
|
SetGroupAccordingToLayers(var_nodes, params_grads, group_params_grads);
|
|
|
|
|
SetGroupAccordingToMemorySize(var_nodes, group_params_grads);
|
|
|
|
|
ReGroupByDtype(var_nodes, params_grads, group_params_grads);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetGroupAccordingToLayers(
|
|
|
|
|
const std::unordered_map<std::string, ir::Node *> &var_nodes,
|
|
|
|
|
const details::ParamsAndGrads &params_grads,
|
|
|
|
|
details::GroupParamsAndGrads *group_params_grads) const {
|
|
|
|
|
using var_dtype = std::pair<std::string, proto::VarType::Type>;
|
|
|
|
|
std::map<var_dtype, size_t> var_idx;
|
|
|
|
|
std::map<std::string, size_t> var_idx;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < params_grads.size(); ++i) {
|
|
|
|
|
auto pos = params_grads[i].first.find_first_of(".");
|
|
|
|
|
|
|
|
|
|
auto dtype = GetDtypeOfVar(var_nodes, params_grads[i].second);
|
|
|
|
|
var_dtype var_key;
|
|
|
|
|
std::string var_key;
|
|
|
|
|
if (pos == std::string::npos) {
|
|
|
|
|
var_key = std::make_pair(params_grads[i].first, dtype);
|
|
|
|
|
var_key = params_grads[i].first;
|
|
|
|
|
} else {
|
|
|
|
|
var_key = std::make_pair(params_grads[i].first.substr(0, pos), dtype);
|
|
|
|
|
var_key = params_grads[i].first.substr(0, pos);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t idx = 0;
|
|
|
|
@ -289,9 +288,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
|
|
|
|
|
local_group_params_grads.emplace_back();
|
|
|
|
|
auto &group_p_g = local_group_params_grads.back();
|
|
|
|
|
|
|
|
|
|
auto &grad_name = group_params_grads->at(j).front().second;
|
|
|
|
|
auto var_type = GetDtypeOfVar(var_nodes, grad_name);
|
|
|
|
|
|
|
|
|
|
size_t local_group_memory_size = 0;
|
|
|
|
|
while (j < group_params_grads->size()) {
|
|
|
|
|
std::for_each(
|
|
|
|
@ -330,12 +326,6 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
|
|
|
|
|
group_memory_size) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto next_var_type =
|
|
|
|
|
GetDtypeOfVar(var_nodes, group_params_grads->at(j).front().second);
|
|
|
|
|
if (next_var_type != var_type) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -348,6 +338,55 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ReGroupByDtype(
|
|
|
|
|
const std::unordered_map<std::string, ir::Node *> &var_nodes,
|
|
|
|
|
const details::ParamsAndGrads ¶ms_grads,
|
|
|
|
|
details::GroupParamsAndGrads *group_params_grads) const {
|
|
|
|
|
if (IsUnifiedDtype(params_grads, var_nodes)) {
|
|
|
|
|
VLOG(1) << "needn't regroup fusion params_grads";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
details::GroupParamsAndGrads new_group_params_grads;
|
|
|
|
|
|
|
|
|
|
for (auto &group_p_g : *group_params_grads) {
|
|
|
|
|
std::map<proto::VarType::Type, size_t> type_idx;
|
|
|
|
|
details::GroupParamsAndGrads local_group_params_grads;
|
|
|
|
|
|
|
|
|
|
for (auto &p_g : group_p_g) {
|
|
|
|
|
auto dtype = GetDtypeOfVar(var_nodes, p_g.second);
|
|
|
|
|
|
|
|
|
|
size_t idx = 0;
|
|
|
|
|
auto var_idx_iter = type_idx.find(dtype);
|
|
|
|
|
if (var_idx_iter != type_idx.end()) {
|
|
|
|
|
idx = var_idx_iter->second;
|
|
|
|
|
} else {
|
|
|
|
|
local_group_params_grads.emplace_back();
|
|
|
|
|
idx = local_group_params_grads.size() - 1;
|
|
|
|
|
type_idx[dtype] = idx;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &local = local_group_params_grads.at(idx);
|
|
|
|
|
local.emplace_back(p_g);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VLOG(10) << "local_group_params_grads size:"
|
|
|
|
|
<< local_group_params_grads.size();
|
|
|
|
|
new_group_params_grads.insert(new_group_params_grads.end(),
|
|
|
|
|
local_group_params_grads.begin(),
|
|
|
|
|
local_group_params_grads.end());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::swap(*group_params_grads, new_group_params_grads);
|
|
|
|
|
|
|
|
|
|
if (VLOG_IS_ON(10)) {
|
|
|
|
|
VLOG(10) << string::Sprintf("ReGroupByDtype(memory_size: %f MB, %u):",
|
|
|
|
|
GetFuseParameterMemorySize(),
|
|
|
|
|
GetFuseParameterGroupsSize());
|
|
|
|
|
PrintGroupInfo(var_nodes, group_params_grads);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
proto::VarType::Type GetDtypeOfVar(
|
|
|
|
|
const std::unordered_map<std::string, Node *> &var_nodes,
|
|
|
|
|
const std::string &name) const {
|
|
|
|
|