@@ -444,6 +444,26 @@ std::vector<std::map<std::string, std::vector<int>>> DeseralizeBatchVarShapes(
   return batch_shapes;
 }
 
+// Replace the -1 in shape to a real number to fake the shape.
+std::vector<std::map<std::string, std::vector<int>>> FakeBatchVarShapes(
+    const framework::ProgramDesc& program) {
+  std::vector<std::map<std::string, std::vector<int>>> res;
+  res.emplace_back();
+  auto& record = res.front();
+  const int fake_batch_size = 3;
+  for (auto* var : program.Block(0).AllVars()) {
+    if (var->GetType() ==
+        framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
+      auto shape = var->GetShape();
+      for (auto& v : shape) {
+        if (v < 0) v = fake_batch_size;
+      }
+      record[var->Name()].assign(shape.begin(), shape.end());
+    }
+  }
+  return res;
+}
+
 // Calculate the average dim of each tensor from the batch shape cache.
 std::unordered_map<std::string, size_t> GetBatchAverageSize(
     const std::vector<std::map<std::string, std::vector<int>>>& batches) {
@@ -478,6 +498,7 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesByBatchSize(
   std::unordered_map<std::string, std::stringstream> var_batchsize_hashes;
   for (auto& batch : batches) {
     for (auto& ele : batch) {
+      PADDLE_ENFORCE(!ele.second.empty());
       int batch_size = ele.second.front();
       // TODO(Superjomn) might consume large memory here, use combine hash.
       var_batchsize_hashes[ele.first] << batch_size;
@@ -538,9 +559,21 @@ std::vector<std::unordered_set<std::string>> AnalysisBatchShapesBySimilarSize(
 
 std::string MemoryOptimizePass::repr() const { return "memory optimize pass"; }
 
+std::pair<size_t, size_t> GetRange(
+    const std::unordered_map<std::string, size_t>& ave_size) {
+  auto res = std::make_pair(std::numeric_limits<size_t>::max(),
+                            std::numeric_limits<size_t>::min());
+  for (auto& item : ave_size) {
+    res.first = std::min(item.second, res.first);
+    res.second = std::max(item.second, res.second);
+  }
+  return res;
+}
+
 void MemoryOptimizePass::RunImpl(Argument* argument) {
   // When force update, should not optimize memory.
-  if (!argument->enable_memory_optim() || argument->memory_optim_force_update())
+  if (!argument->enable_memory_optim() ||
+      argument->static_memory_optim_force_update())
     return;
   graph_ = argument->main_graph_ptr();
 
@@ -549,11 +582,26 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
       argument->model_program_path_valid() ? argument->model_program_path()
                                            : "");
   VLOG(3) << "Load memory cache from " << path;
-  if (inference::IsFileExists(path)) {
-    VLOG(4) << "Performing memory optimize";
-    auto batches = DeseralizeBatchVarShapes(path);
+  std::vector<std::map<std::string, std::vector<int>>> batches;
+
+  if (argument->static_memory_optim() && inference::IsFileExists(path)) {
+    string::PrettyLogInfo("--- Performing static memory optimize");
+    batches = DeseralizeBatchVarShapes(path);
+  } else {
+    string::PrettyLogInfo("--- Performing dynamic memory optimize");
+    batches = FakeBatchVarShapes(argument->main_program());
+  }
   auto var_batch_ave_size = GetBatchAverageSize(batches);
 
+  // Get min and max memory size.
+  const auto range = GetRange(var_batch_ave_size);
+  const int cluster_size = std::max(
+      static_cast<int>((range.second - range.first) / 100 /*cluster num*/),
+      1024);
+  const int cluster_size1 = std::max(
+      static_cast<int>((range.second - range.first) / 1000 /*cluster num*/),
+      1024);
+
   std::unordered_map<std::string, Node*> tensor_nodes;
   space_table_t space_table;
   CollectVarMemorySize(var_batch_ave_size, &tensor_nodes, &space_table);
@@ -564,6 +612,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
   std::vector<std::function<MemoryAllocation()>> strategies;
 
   for (int sort_kind = 0; sort_kind < 2; sort_kind++) {
+    if (argument->static_memory_optim()) {
+      // This strategy only make scene in static memory optimize.
       strategies.emplace_back([&, sort_kind] {
         auto clustered_vars_by_batch_size =
             AnalysisBatchShapesByBatchSize(batches);
@@ -572,22 +622,23 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
                       space_table, &reuse_table, sort_kind, &allocation);
         return allocation;
       });
+    }
 
     strategies.emplace_back([&, sort_kind] {
-      auto clustered_vars_by_ave_size = AnalysisBatchShapesBySimilarSize(
-          space_table, batches, 1024);  // interval 1kb
+      auto clustered_vars_by_ave_size =
+          AnalysisBatchShapesBySimilarSize(space_table, batches, cluster_size);
       MemoryAllocation allocation;
-      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-                    space_table, &reuse_table, sort_kind, &allocation);
+      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+                    &reuse_table, sort_kind, &allocation);
       return allocation;
     });
 
     strategies.emplace_back([&, sort_kind] {
-      auto clustered_vars_by_ave_size = AnalysisBatchShapesBySimilarSize(
-          space_table, batches, 1024 * 1024);  // interval 1MB
+      auto clustered_vars_by_ave_size =
+          AnalysisBatchShapesBySimilarSize(space_table, batches, cluster_size1);
       MemoryAllocation allocation;
-      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-                    space_table, &reuse_table, sort_kind, &allocation);
+      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+                    &reuse_table, sort_kind, &allocation);
       return allocation;
     });
 
@@ -596,8 +647,8 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
           space_table, batches,
           std::numeric_limits<int>::max());  // no intervals
       MemoryAllocation allocation;
-      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size,
-                    space_table, &reuse_table, sort_kind, &allocation);
+      MakeReusePlan(clustered_vars_by_ave_size, var_batch_ave_size, space_table,
+                    &reuse_table, sort_kind, &allocation);
       return allocation;
     });
   }
@@ -615,19 +666,15 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
     }
   }
   if (!best_strategy) {
-    LOG(ERROR)
-        << "This model makes poor memory optimize, skip memory optimize";
+    LOG(ERROR) << "This model makes poor memory optimize, skip memory optimize";
     return;
   }
   auto memory_allocation = (*best_strategy)();
 
-  string::PrettyLogH2(
+  string::PrettyLogInfo(
       "--- Saved %.2f%s memory for workspace(temporary variables)",
       memory_allocation.GetSavingRatio() * 100, "%");
-  string::PrettyLogDetail("--- Allocated %d MB",
-                          memory_allocation.allocated / 1024. / 1024.);
-  string::PrettyLogDetail("--- Saved %d MB",
-                          memory_allocation.saved / 1024. / 1024.);
 
   argument->main_graph().Set(framework::ir::kGraphToProgramVarsToRemove,
                              new std::unordered_set<std::string>);
   auto& vars2remove =
@@ -636,7 +683,6 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
 
   PerformReusePlan(reuse_table, memory_allocation.sort_kind, &vars2remove);
   argument->SetMemoryOptimSortKind(memory_allocation.sort_kind);
-  }
 }
 
 float MemoryOptimizePass::MemoryAllocation::GetSavingRatio() const {