optimize memory reuse according to stream info

pull/2754/head
laiyongqiang 5 years ago
parent 2f565f4c20
commit ff80587ca7

@ -99,6 +99,11 @@ uint8_t *MemoryManager::MallocStaticMem(size_t size, bool communication_mem) {
} else {
align_size = GetCommonAlignSize(size);
}
MS_LOG(INFO) << "Malloc Memory for Static: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
<< "] static[" << total_static_size_ << "])"
<< " malloc [" << align_size << "] communication_mem: " << communication_mem;
if (static_mem_offset_ < align_size) {
MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
<< "] static[" << total_static_size_ << "])"
@ -126,6 +131,11 @@ uint8_t *MemoryManager::MallocDynamicMem(size_t size, bool communication_mem) {
} else {
align_size = GetCommonAlignSize(size);
}
MS_LOG(INFO) << "Malloc Memory for Dynamic: total[" << device_mem_size_ << "](dynamic[" << total_dynamic_size_
<< "] static[" << total_static_size_ << "])"
<< " malloc [" << align_size << "] communication_mem: " << communication_mem;
uint64_t offset = dynamic_mem_offset_;
auto new_offset = dynamic_mem_offset_ + align_size;
if (new_offset > static_mem_offset_) {
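The two offsets in these logs move toward each other from opposite ends of device memory: dynamic allocations grow upward from offset 0, static allocations grow downward from device_mem_size_, and either malloc reports out of memory once the regions would meet. Below is a minimal sketch of that two-ended scheme with made-up names rather than the MemoryManager API; alignment and the exact overlap checks are simplified.

// Toy illustration of the two-ended offset layout implied by the checks above;
// alignment, logging, and the exact error handling of MemoryManager are omitted.
#include <cstdint>
#include <stdexcept>

struct TwoEndedArena {
  uint64_t total_size;
  uint64_t dynamic_offset{0};  // dynamic region grows upward from 0
  uint64_t static_offset;      // static region grows downward from total_size

  explicit TwoEndedArena(uint64_t size) : total_size(size), static_offset(size) {}

  uint64_t MallocDynamic(uint64_t aligned_size) {
    if (dynamic_offset + aligned_size > static_offset) {
      throw std::runtime_error("out of memory: dynamic region reached static region");
    }
    uint64_t offset = dynamic_offset;
    dynamic_offset += aligned_size;
    return offset;
  }

  uint64_t MallocStatic(uint64_t aligned_size) {
    if (static_offset < aligned_size || static_offset - aligned_size < dynamic_offset) {
      throw std::runtime_error("out of memory: static region reached dynamic region");
    }
    static_offset -= aligned_size;
    return static_offset;
  }
};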

@ -329,22 +329,25 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
return;
}
size_t total_summary_size = 0;
for (auto &node_item : summary_nodes) {
auto node = node_item.second.first;
size_t index = IntToSize(node_item.second.second);
MS_LOG(INFO) << "set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
if (kernel_output_refs_.find(node.get()) != kernel_output_refs_.end()) {
KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
kernel_ref->ref_count_ = kMaxRefCount;
kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
total_summary_size += kernel_ref->size_;
MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
} else {
MS_LOG(WARNING) << "can't find summary node's kernel_def " << node->fullname_with_scope();
MS_LOG(WARNING) << "Can't find summary node's kernel_def " << node->fullname_with_scope() << " index: " << index;
}
}
#ifdef MEM_REUSE_DEBUG
auto graph = *graph_;
MemReuseChecker::GetInstance().CheckMemReuseIR(total_refs_list_, kernel_def_ptr_list_, &graph);
#endif
MS_LOG(INFO) << "Special Tensor total size: SummaryNodes: " << total_summary_size;
}
void MemReuseUtil::SetGraphOutputRefCount() {

@ -17,6 +17,9 @@
#include "pre_activate/mem_reuse/mem_reuse_allocator.h"
#include "pre_activate/mem_reuse/mem_reuse.h"
#include "pre_activate/mem_reuse/mem_reuse_checker.h"
#ifdef ENABLE_D
#include "device/ascend/ascend_stream_assign.h"
#endif
namespace mindspore {
namespace memreuse {
@ -34,6 +37,9 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) {
wk->size_ = AlignMemorySize(wk->size_);
wk->ref_count_ = 1;
}
#ifdef ENABLE_D
stream_groups_ = device::ascend::AscendStreamAssign::GetInstance().get_stream_group();
#endif
}
void BestFitMemReuse::InitKernelDependence() {
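stream_groups_ is a std::vector<std::vector<uint32_t>> taken from AscendStreamAssign::get_stream_group(); the reuse pass below treats each inner vector as an ordered list of stream ids and only trusts cross-stream reuse when the two streams appear in order inside such a group. A hypothetical value, purely for illustration:

// Hypothetical contents; the real groups come from AscendStreamAssign and
// depend on how streams were assigned for the compiled graph.
std::vector<std::vector<uint32_t>> stream_groups = {
  {0, 1, 2},  // within a group, a larger index is treated as executing after a smaller one
  {3, 4},
};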
@ -63,21 +69,58 @@ void BestFitMemReuse::InitKernelDependence() {
}
}
bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const KernelDefPtr &kernel_prev) {
bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf) {
// determine whether the kernel_curr can reuse kernel_prev's output tensor membuf
MS_EXCEPTION_IF_NULL(kernel_curr);
MS_EXCEPTION_IF_NULL(mem_buf);
auto kernel_prev = mem_buf->used_kernel_;
MS_EXCEPTION_IF_NULL(kernel_prev);
auto curr_stream_id = kernel_curr->stream_id();
auto prev_stream_id = kernel_prev->stream_id();
if (curr_stream_id == prev_stream_id) {
mem_buf->type_ = IN_STREAM_REUSE;
return true;
}
bool reuse_between_streams = true;
for (auto &stream_group : stream_groups_) {
size_t cur_index = UINT32_MAX;
size_t prev_index = UINT32_MAX;
for (size_t index = 0; index < stream_group.size(); index++) {
if (curr_stream_id == stream_group[index]) {
cur_index = index;
continue;
}
if (prev_stream_id == stream_group[index]) {
prev_index = index;
continue;
}
}
if ((prev_index != UINT32_MAX) && (cur_index == UINT32_MAX || (prev_index > cur_index))) {
// if the previous stream and the current stream are not in the same group, the membuf can't be reused
// if the previous stream comes after the current stream within a group, the membuf can't be reused
reuse_between_streams = false;
break;
}
}
if (reuse_between_streams) {
mem_buf->type_ = BETWEEN_STREAMS_REUSE;
return true;
}
auto iter = kernel_front_map_.find(kernel_curr);
if (iter == kernel_front_map_.end()) {
MS_LOG(EXCEPTION) << kernel_curr->scope_full_name() << " is not initialized.";
}
auto kernel_curr_front = iter->second;
return kernel_curr_front.count(kernel_prev);
auto depend_count = kernel_curr_front.count(kernel_prev);
if (depend_count) {
mem_buf->type_ = KERNEL_DEPENDENCE_REUSE;
return true;
}
return false;
}
void BestFitMemReuse::AssignNodeOutputOffset() {
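To make the ordering rule concrete, here is a self-contained restatement of the cross-stream part of IsUsable(); the helper is not part of the patch, the stream ids and groups are hypothetical, and the kernel-dependence fallback is left out.

// A membuf last used on prev_stream may be reused by a kernel on curr_stream only
// if no stream group contains prev_stream without curr_stream, or with prev_stream
// placed after curr_stream.
#include <cstddef>
#include <cstdint>
#include <vector>

bool CanReuseAcrossStreams(uint32_t curr_stream, uint32_t prev_stream,
                           const std::vector<std::vector<uint32_t>> &groups) {
  if (curr_stream == prev_stream) {
    return true;  // same stream: program order alone keeps the reuse safe
  }
  for (const auto &group : groups) {
    size_t curr_index = SIZE_MAX;
    size_t prev_index = SIZE_MAX;
    for (size_t i = 0; i < group.size(); ++i) {
      if (group[i] == curr_stream) curr_index = i;
      if (group[i] == prev_stream) prev_index = i;
    }
    if (prev_index != SIZE_MAX && (curr_index == SIZE_MAX || prev_index > curr_index)) {
      return false;
    }
  }
  return true;
}

// With groups {{0, 1, 2}, {3, 4}}: reuse from stream 0 to stream 2 is allowed, reuse
// from stream 2 back to stream 0 is not, and reuse from stream 0 to stream 3 is not,
// because stream 3 is outside the group that contains stream 0.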
@ -135,7 +178,7 @@ std::map<size_t, size_t> BestFitMemReuse::GetReusableMembufMap(size_t tensor_siz
auto membuf = membuf_ptr_list_[i];
auto index = i;
bool is_membuf_ok = membuf->status_ == kUnused && membuf->size_ >= tensor_size;
if (is_membuf_ok && IsUsable(current_kernel_, membuf->used_kernel_)) {
if (is_membuf_ok && IsUsable(current_kernel_, membuf)) {
(void)size_map.insert(std::make_pair(membuf->size_, index));
break;
}
@ -163,8 +206,8 @@ void BestFitMemReuse::SplitMembuf(const KernelRefCount *tensor_desc, size_t memb
auto bias = membuf->size_ - tensor_desc->size_;
membuf->size_ = tensor_desc->size_;
// check whether the split membuf can be merged
auto new_membuf =
std::make_shared<Membuf>(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex, current_kernel_);
auto new_membuf = std::make_shared<Membuf>(kUnused, bias, membuf->offset_ + membuf->size_, kInvalidIndex,
membuf->type_, current_kernel_);
(void)membuf_ptr_list_.insert(membuf_ptr_list_.begin() + SizeToInt(membuf_index + 1), new_membuf);
}
@ -176,7 +219,7 @@ void BestFitMemReuse::AddNewMembufPtr(KernelRefCount *tensor_desc, int flag) {
}
auto membuf_size = tensor_desc->size_;
auto real_index = GetRealIndex(IntToSize(tensor_desc->index_), flag);
auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, current_kernel_);
auto membuf = std::make_shared<Membuf>(kReused, membuf_size, membuf_offset, real_index, NEW, current_kernel_);
membuf_ptr_list_.push_back(membuf);
tensor_desc->offset_ = membuf_offset;
}
@ -242,7 +285,7 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
auto membuf_next = (*next_iter);
MS_EXCEPTION_IF_NULL(membuf_next);
if (membuf_next->status_ == kUnused) {
bool is_merge = IsUsable(current_kernel_, membuf_next->used_kernel_);
bool is_merge = IsUsable(current_kernel_, membuf_next);
if (is_merge) {
membuf->size_ += membuf_next->size_;
(void)membuf_ptr_list_.erase(next_iter);
@ -254,7 +297,7 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
auto membuf_prev = (*prev_iter);
MS_EXCEPTION_IF_NULL(membuf_prev);
if (membuf_prev->status_ == kUnused) {
bool is_merge = IsUsable(current_kernel_, membuf_prev->used_kernel_);
bool is_merge = IsUsable(current_kernel_, membuf_prev);
if (is_merge) {
membuf->size_ += membuf_prev->size_;
membuf->offset_ = membuf_prev->offset_;
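Both ReleaseMembuf hunks keep the usual coalescing of adjacent free blocks but now hand the whole neighbor membuf to IsUsable(), so the merge decision can also record the reuse type. Below is a stripped-down sketch of coalescing gated by a caller-supplied predicate; the names are illustrative, not the BestFitMemReuse API.

#include <cstddef>
#include <functional>
#include <vector>

struct Block {
  bool free;
  size_t offset;
  size_t size;
};

// Absorb the right-hand neighbor of blocks[i] when it is free and the predicate
// (standing in for IsUsable) approves it; the real pass also merges to the left.
void CoalesceRight(std::vector<Block> *blocks, size_t i,
                   const std::function<bool(const Block &)> &can_reuse) {
  const size_t next = i + 1;
  if (next < blocks->size() && (*blocks)[next].free && can_reuse((*blocks)[next])) {
    (*blocks)[i].size += (*blocks)[next].size;
    blocks->erase(blocks->begin() + static_cast<std::ptrdiff_t>(next));
  }
}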

@ -40,11 +40,12 @@ static constexpr int kDynamicMem = -1;
static constexpr int kWorkspaceMem = 1;
static constexpr size_t kTotalSize = 0;
enum Status { kUnused, kReused };
enum MEMTYPE { NEW, IN_STREAM_REUSE, BETWEEN_STREAMS_REUSE, KERNEL_DEPENDENCE_REUSE };
class Membuf {
public:
Membuf() = default;
Membuf(Status status, size_t size, size_t offset, int index, const KernelDefPtr &used_kernel)
: status_(status), size_(size), offset_(offset), index_(index), used_kernel_(used_kernel) {}
Membuf(Status status, size_t size, size_t offset, int index, MEMTYPE type, const KernelDefPtr &used_kernel)
: status_(status), size_(size), offset_(offset), index_(index), type_(type), used_kernel_(used_kernel) {}
~Membuf() = default;
// Memory block status flags
Status status_ = kUnused;
@ -52,6 +53,7 @@ class Membuf {
size_t offset_{0};
// Index of the tensor stored in this memory block at a given moment
int index_{0};
MEMTYPE type_{NEW};
KernelDefPtr used_kernel_;
};
using MembufPtr = std::shared_ptr<Membuf>;
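The MEMTYPE argument is positional, sitting between index and used_kernel, so every construction site passes a tag explicitly: NEW for a freshly allocated membuf, one of the *_REUSE values when an existing membuf is handed out. A usage sketch with placeholder values; kernel_def stands for whatever KernelDefPtr the caller already holds:

// Placeholder values; size, offset and index come from the allocator state.
KernelDefPtr kernel_def = current_kernel_;
auto buf = std::make_shared<Membuf>(kReused, /*size=*/512, /*offset=*/0,
                                    /*index=*/3, BETWEEN_STREAMS_REUSE, kernel_def);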
@ -122,10 +124,10 @@ class BestFitMemReuse {
/**
* determine whether the kernel_curr can reuse the output tensor addr of kernel_prev
* @param kernel_curr, current kernel
* @param kernel_prev, the membuf used by this kernel
* @param mem_buf, the candidate membuf to reuse
* @return bool
*/
bool IsUsable(const KernelDefPtr &kernel_curr, const KernelDefPtr &kernel_prev);
bool IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr &mem_buf);
/**
* init the dependence of all kernels in the graph
*/
@ -150,6 +152,7 @@ class BestFitMemReuse {
std::vector<MembufPtr> membuf_ptr_list_;
// kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
std::vector<std::vector<uint32_t>> stream_groups_;
};
} // namespace memreuse
} // namespace mindspore

@ -413,7 +413,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) {
void MemReuseChecker::SetMembuInfos(const KernelDef *op_def, const std::vector<MembufPtr> &membuf_ptr_list) {
std::vector<MembufPtr> curr_mem_infos;
for (const auto &mem : membuf_ptr_list) {
auto mem_checker = std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->used_kernel_);
auto mem_checker =
std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_);
curr_mem_infos.push_back(mem_checker);
}
membuf_all_infos_.push_back(curr_mem_infos);
@ -427,7 +428,8 @@ void MemReuseChecker::SetAddNewMembuInfos(const KernelDef *op_def, const std::ve
std::vector<MembufPtr> add_new_curr_mem;
for (const auto &mem : membuf_ptr_list) {
auto mem_checker = std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->used_kernel_);
auto mem_checker =
std::make_shared<Membuf>(mem->status_, mem->size_, mem->offset_, mem->index_, mem->type_, mem->used_kernel_);
add_new_curr_mem.push_back(mem_checker);
}
add_new_mem_infos_.push_back(add_new_curr_mem);
@ -451,6 +453,7 @@ void MemReuseChecker::ExportEachMembufInfo(std::ofstream &ofs) {
<< "mem_size\t"
<< "mem_head\t"
<< "mem_tail\t"
<< "mem_type\t"
<< "used_kernel\n";
size_t curr_used = 0;
size_t curr_allocated = 0;
@ -461,8 +464,8 @@ void MemReuseChecker::ExportEachMembufInfo(std::ofstream &ofs) {
<< "streamID[@" << membuf->used_kernel_->stream_id() << "]"
<< "\t"
<< "#" << static_cast<int>(membuf->status_) << "\t%" << membuf->index_ << "T"
<< "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t" << membuf->offset_ + membuf->size_ << "\t"
<< GetSplitName(used_kernel) << "\n";
<< "\t" << membuf->size_ << "\t" << membuf->offset_ << "\t\t" << membuf->offset_ + membuf->size_ << "\t"
<< "\t" << static_cast<int>(membuf->type_) << "\t" << GetSplitName(used_kernel) << "\n";
if (membuf->status_ == kReused) {
curr_used += membuf->size_;
}
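The new mem_type column is written as the enum's integer value. A hypothetical helper, not part of this patch, for mapping that integer back to a readable label when inspecting the exported file:

// MEMTYPE values are exported as 0 (NEW), 1 (IN_STREAM_REUSE), 2 (BETWEEN_STREAMS_REUSE)
// and 3 (KERNEL_DEPENDENCE_REUSE), following the enum declaration order.
inline const char *MemTypeName(MEMTYPE type) {
  switch (type) {
    case NEW: return "NEW";
    case IN_STREAM_REUSE: return "IN_STREAM_REUSE";
    case BETWEEN_STREAMS_REUSE: return "BETWEEN_STREAMS_REUSE";
    case KERNEL_DEPENDENCE_REUSE: return "KERNEL_DEPENDENCE_REUSE";
    default: return "UNKNOWN";
  }
}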
