reuse communication op output's memory

pull/2966/head
laiyongqiang 5 years ago
parent 45ad430af2
commit 68c78ab6bb

@@ -25,7 +25,8 @@
 namespace mindspore {
 namespace memreuse {
 enum RefCountType { kDynamicRefCount, kStaticRefCount };
-enum NodeType { NORMAL, SPECIAL };
+enum NodeType { COMMON_NODE, COMMUNICATION_NODE };
+enum KernelRefType { COMMON, REFNODE_OUTPUT, COMM_NOTREUSE, COMM_REUSE, SUMMARY };
 static constexpr int kInitIndex = -1;
 class KernelRefCount {
  public:
@@ -36,6 +37,7 @@ class KernelRefCount {
   size_t offset_;
   size_t size_;
   int index_;
+  KernelRefType type_;
   // remember to reset offset
   KernelRefCount()
       : stream_id_(0),
@@ -44,6 +46,7 @@
         offset_(0),
         size_(0),
         index_(kInitIndex),
+        type_(COMMON),
         reftype_(kStaticRefCount) {}
   ~KernelRefCount() = default;
   void SetKernelRefCountInfo(int index, size_t size, RefCountType reftype);
@@ -65,7 +68,7 @@ class KernelDef {
   KernelMap inputs_;
   KernelMap outputs_;
   KernelMap wk_space_;
-  NodeType dirty = NORMAL;
+  NodeType type_ = COMMON_NODE;
   KernelDef() = default;
   ~KernelDef() = default;
   void set_input_refs(const KernelRefCountPtrList &kernelRefPtrList) { input_refs_ = kernelRefPtrList; }

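A minimal sketch, not MindSpore source, of how the two new enums partition tensors; the output rule follows InitDynamicOutputKernelRef below, the input rule follows SetInputMap, and SetKernelDefMap tags the kernel itself:

// sketch.cc -- hypothetical, standalone illustration of the new tagging rules
#include <cstddef>

enum NodeType { COMMON_NODE, COMMUNICATION_NODE };
enum KernelRefType { COMMON, REFNODE_OUTPUT, COMM_NOTREUSE, COMM_REUSE, SUMMARY };

// every output of a communication op is reusable; other outputs are either
// ref-node outputs (aliases of another tensor's memory) or plain COMMON
KernelRefType ClassifyOutput(bool is_comm_op, bool is_in_ref_output_map) {
  if (is_comm_op) return COMM_REUSE;
  return is_in_ref_output_map ? REFNODE_OUTPUT : COMMON;
}

// a communication op reuses its input buffer in place only when there is
// exactly one input; with several inputs the buffers must not be reused
KernelRefType ClassifyInput(bool is_comm_op, std::size_t input_tensor_num) {
  if (!is_comm_op) return COMMON;
  return input_tensor_num == 1 ? COMM_REUSE : COMM_NOTREUSE;
}
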
@@ -46,6 +46,8 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
     if (iter == kernel_output_refs_.end()) {
       auto output_sizes = kernel_mod->GetOutputSizeList();
       KernelRefCountPtrList kernel_refs;
+      bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel_cnode);
+      size_t output_index = 0;
       for (auto size : output_sizes) {
         total_dy_size_ += size;
         // do not MallocDynamicMem just record this
@@ -54,9 +56,20 @@ bool MemReuseUtil::InitDynamicOutputKernelRef() {
         auto curr_stream_id = AnfAlgo::GetStreamId(kernel_cnode);
         kernel_ref->stream_id_ = curr_stream_id;
         kernel_ref->SetKernelRefCountInfo(index, size, kDynamicRefCount);
+        if (is_comm_op) {
+          kernel_ref->type_ = COMM_REUSE;
+        } else {
+          session::AnfWithOutIndex out_pair(kernel_cnode, output_index);
+          if (graph_->IsInRefOutputMap(out_pair)) {
+            kernel_ref->type_ = REFNODE_OUTPUT;
+          } else {
+            kernel_ref->type_ = COMMON;
+          }
+        }
         kernel_refs.push_back(kernel_ref);
         kernel_out_ref_num++;
         total_refs_list_.push_back(kernel_ref);
+        output_index++;
       }
       if (!kernel_refs.empty()) {
         kernel_output_refs_[key] = kernel_refs;
@@ -155,9 +168,19 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr
   MS_EXCEPTION_IF_NULL(kernel);
   MS_EXCEPTION_IF_NULL(kernel_def_ptr);
   auto key = kernel.get();
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
+  bool is_comm_op = AnfAlgo::IsCommunicationOp(kernel);
+  size_t input_tensor_num = AnfAlgo::GetInputTensorNum(kernel);
+  for (size_t i = 0; i < input_tensor_num; ++i) {
     auto ref_ptr = GetKernelInputRef(kernel, i);
     if (ref_ptr != nullptr) {
+      if (is_comm_op) {
+        if (input_tensor_num == 1) {
+          ref_ptr->type_ = COMM_REUSE;
+        } else {
+          ref_ptr->type_ = COMM_NOTREUSE;
+        }
+      }
       if (ref_ptr->reftype() == kStaticRefCount) {
         continue;
       } else if (ref_ptr->reftype() == kDynamicRefCount) {
@@ -258,6 +281,11 @@ void MemReuseUtil::SetKernelDefMap() {
     auto key = kernel.get();
     kernel_def_ptr->set_input_refs(kernel_def_ptr->inputs_[key]);
     kernel_def_ptr->set_output_refs(kernel_def_ptr->outputs_[key]);
+    if (AnfAlgo::IsCommunicationOp(kernel)) {
+      kernel_def_ptr->type_ = COMMUNICATION_NODE;
+    } else {
+      kernel_def_ptr->type_ = COMMON_NODE;
+    }
     kernel_def_ptr_list_.push_back(kernel_def_ptr);
     kernel_map_[key] = kernel_def_ptr;
   }
@@ -337,6 +365,7 @@ void MemReuseUtil::SetSummaryNodesRefCount() {
       KernelRefCountPtr kernel_ref = kernel_output_refs_[node.get()][index];
       kernel_ref->ref_count_ = kMaxRefCount;
       kernel_ref->ref_count_dynamic_use_ = kMaxRefCount;
+      kernel_ref->type_ = SUMMARY;
       total_summary_size += kernel_ref->size_;
       MS_LOG(INFO) << "Set summary node's ref count, node: " << node->fullname_with_scope() << " index: " << index;
     } else {

@@ -30,11 +30,11 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) {
   set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list());
   // check info Correctness
   for (auto &tensor : tensor_ptr_list_) {
-    tensor->size_ = AlignMemorySize(tensor->size_);
+    tensor->size_ = AlignCommonMemorySize(tensor->size_);
   }
   // align wk size to 512 && refcount == 1
   for (auto &wk : wk_tensor_list_) {
-    wk->size_ = AlignMemorySize(wk->size_);
+    wk->size_ = AlignCommonMemorySize(wk->size_);
     wk->ref_count_ = 1;
   }
 #ifdef ENABLE_D
@@ -123,11 +123,23 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr
   return false;
 }
-void BestFitMemReuse::AssignNodeOutputOffset() {
+void BestFitMemReuse::AssignCommonNodeOutputOffset() {
+  MS_EXCEPTION_IF_NULL(current_kernel_);
   for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
     size_t index = GetTensorIndex(tensor_idx);
     auto tensor_desc = tensor_ptr_list_[index];
     MS_EXCEPTION_IF_NULL(tensor_desc);
+    if (tensor_desc->type_ == REFNODE_OUTPUT) {
+      total_refoutput_size += tensor_desc->size_;
+      continue;
+    } else if (tensor_desc->type_ == COMM_NOTREUSE) {
+      total_comm_not_reuse_size += tensor_desc->size_;
+    } else if (tensor_desc->type_ == COMM_REUSE) {
+      // use the communication-aligned size for the communication op's single input
+      tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_);
+      total_comm_reuse_size += tensor_desc->size_;
+    }
     auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_);
     if (!reusable_membuf_map.empty()) {
       auto membuf_index = reusable_membuf_map.begin()->second;
@@ -138,11 +150,91 @@ void BestFitMemReuse::AssignNodeOutputOffset() {
       AddNewMembufPtr(tensor_desc.get(), kDynamicMem);
 #ifdef MEM_REUSE_DEBUG
       MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
+#endif
+    }
+    // skip the left align border so the communication op's single input points to usable memory
+    if (tensor_desc->type_ == COMM_REUSE) {
+      tensor_desc->offset_ += kDefaultMemAlignSize;
+    }
+  }
+}
+void BestFitMemReuse::AssignCommunicationNodeOutputOffset() {
+  size_t total_kernel_output_size = 0;
+  size_t output_num = 0;
+  // sum all output sizes
+  MS_EXCEPTION_IF_NULL(current_kernel_);
+  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+    size_t index = GetTensorIndex(tensor_idx);
+    auto tensor_desc = tensor_ptr_list_[index];
+    MS_EXCEPTION_IF_NULL(tensor_desc);
+    output_num++;
+    if (tensor_desc->type_ == COMM_REUSE) {
+      total_comm_reuse_size += tensor_desc->size_;
+      total_comm_output_reuse_size += tensor_desc->size_;
+      total_kernel_output_size += tensor_desc->size_;
+    } else {
+      MS_LOG(ERROR) << "All outputs of a communication op should be memory reused, Kernel: "
+                    << current_kernel_->scope_full_name();
+      continue;
+    }
+  }
+  total_kernel_output_size = AlignCommunicationMemorySize(total_kernel_output_size);
+  // grow the first output by the left align border and the last output by the right align border,
+  // so the border memory is allocated together with the tensors
+  size_t output_index = 0;
+  for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+    size_t index = GetTensorIndex(tensor_idx);
+    auto tensor_desc = tensor_ptr_list_[index];
+    MS_EXCEPTION_IF_NULL(tensor_desc);
+    if (output_index == 0 || output_index == output_num - 1) {
+      tensor_desc->size_ += kDefaultMemAlignSize;
+    }
+    output_index++;
+  }
+  auto reusable_membuf_map = GetReusableMembufMap(total_kernel_output_size);
+  if (!reusable_membuf_map.empty()) {
+    auto membuf_index = reusable_membuf_map.begin()->second;
+    output_index = 0;
+    for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+      size_t index = GetTensorIndex(tensor_idx);
+      auto tensor_desc = tensor_ptr_list_[index];
+      MS_EXCEPTION_IF_NULL(tensor_desc);
+      ReuseExistMembuf(tensor_desc.get(), membuf_index + output_index, kDynamicMem);
+      // skip the left align border so the first output points to usable memory
+      if (output_index == 0) {
+        tensor_desc->offset_ += kDefaultMemAlignSize;
+      }
+      output_index++;
+    }
+  } else {
+    // no membuf can be reused, append new membufs to membuf_ptr_list_
+    output_index = 0;
+    for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) {
+      size_t index = GetTensorIndex(tensor_idx);
+      auto tensor_desc = tensor_ptr_list_[index];
+      MS_EXCEPTION_IF_NULL(tensor_desc);
+      AddNewMembufPtr(tensor_desc.get(), kDynamicMem);
+      // skip the left align border so the first output points to usable memory
+      if (output_index == 0) {
+        tensor_desc->offset_ += kDefaultMemAlignSize;
+      }
+      output_index++;
+#ifdef MEM_REUSE_DEBUG
+      MemReuseChecker::GetInstance().IsAddNewMembuf_ = true;
 #endif
     }
   }
 }
+void BestFitMemReuse::AssignNodeOutputOffset() {
+  if (current_kernel_->type_ == COMMUNICATION_NODE) {
+    AssignCommunicationNodeOutputOffset();
+  } else {
+    AssignCommonNodeOutputOffset();
+  }
+}
 void BestFitMemReuse::AssignNodeWorkspaceOffset() {
   for (auto &wk_idx : current_kernel_->GetWorkspaceRefIndexs()) {
     size_t index = GetWorkspaceIndex(wk_idx);
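
For intuition, the packing AssignCommunicationNodeOutputOffset produces for a hypothetical communication op with three 1024-byte outputs, taking kDefaultMemAlignSize = 512:

  total_kernel_output_size = AlignCommunicationMemorySize(3 * 1024) = 512 + 3072 + 512 = 4096

  | left border 512 | out0 1024 | out1 1024 | out2 1024 | right border 512 |

out0's offset_ is bumped by 512 so it skips the left border, while out0 and out2 each grow by 512 in size_ so the two borders travel with the first and last tensor.
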
@@ -307,11 +399,17 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) {
   }
 }
-size_t BestFitMemReuse::AlignMemorySize(size_t size) const {
+size_t BestFitMemReuse::AlignCommonMemorySize(size_t size) const {
   // memory size 512 align
   return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize;
 }
+size_t BestFitMemReuse::AlignCommunicationMemorySize(size_t size) const {
+  // 512-align the memory size, then add the communication borders: left align border - data - right align border
+  return kDefaultMemAlignSize + (size + kDefaultMemAlignSize - 1) / kDefaultMemAlignSize * kDefaultMemAlignSize +
+         kDefaultMemAlignSize;
+}
 size_t BestFitMemReuse::GetAllocatedSize() {
   size_t AllocatedSize = kTotalSize;
   if (membuf_ptr_list_.empty()) {
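
Worked numbers for the two helpers, taking kDefaultMemAlignSize = 512 (kAttAlignSize's value is not visible in this diff; the first result is the one asserted by the unit test at the bottom of this commit):

  AlignCommonMemorySize(510)        = 1024
  AlignCommunicationMemorySize(510) = 512 + 512 + 512 = 1536
                                      (left border + 512-aligned data + right border)

The communication variant uses a plain round-up for the data portion and spends its extra space on the two guard borders instead.
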
@@ -400,6 +498,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) {
     ++op_num;
 #endif
   }
+  MS_LOG(INFO) << "Special Tensor total size: RefOutput: " << total_refoutput_size
+               << " CommReuse: " << total_comm_reuse_size << " CommOutputReuse: " << total_comm_output_reuse_size
+               << " CommNotReuse: " << total_comm_not_reuse_size;
 #ifdef MEM_REUSE_DEBUG
   MemReuseChecker::GetInstance().ExportMembufInfoIR();
   MemReuseChecker::GetInstance().ExportAddNewMmebufIR();

@@ -74,6 +74,14 @@ class BestFitMemReuse {
    * Assign output tensor memory offset of current kernel
    */
   void AssignNodeOutputOffset();
+  /**
+   * Assign output tensor memory offset of a common kernel
+   */
+  void AssignCommonNodeOutputOffset();
+  /**
+   * Assign output tensor memory offset of a communication kernel
+   */
+  void AssignCommunicationNodeOutputOffset();
   /**
    * Update input tensor's status of current kernel, and the status of membuf used by current kernel
    */
@@ -110,8 +118,10 @@ class BestFitMemReuse {
   void AddNewMembufPtr(KernelRefCount *tensor_desc, int flag);
   // Merge unused membuf
   void ReleaseMembuf(size_t tensor_index, int flag);
-  // Memory address alignment 512
-  size_t AlignMemorySize(size_t size) const;
+  // Memory address alignment for common memory
+  size_t AlignCommonMemorySize(size_t size) const;
+  // Memory address alignment for communication memory
+  size_t AlignCommunicationMemorySize(size_t size) const;
   int GetRealIndex(size_t index, int flag = kDynamicMem) const;
   size_t GetTensorIndex(int index) const;
   size_t GetWorkspaceIndex(int index) const;
@@ -153,6 +163,10 @@ class BestFitMemReuse {
   // kernel_front_map_, key: the kernel_def, value: kernels before this kernel_def
   std::map<KernelDefPtr, std::set<KernelDefPtr>> kernel_front_map_;
   std::vector<std::vector<uint32_t>> stream_groups_;
+  size_t total_refoutput_size{0};
+  size_t total_comm_reuse_size{0};
+  size_t total_comm_output_reuse_size{0};
+  size_t total_comm_not_reuse_size{0};
 };
 }  // namespace memreuse
 }  // namespace mindspore

@@ -170,12 +170,14 @@ void MemReuseChecker::CheckMemReuseIR(const KernelRefCountPtrList &total_refs_li
   ofs << "all_tensor_refs:\n";
   ofs << "index:"
       << "\tsize:"
-      << "\trefcount:\n";
+      << "\trefcount:"
+      << "\ttype:\n";
   for (auto &ref : total_refs_list) {
     ofs << "%" << ref->index_ << "T"
         << "\t"
        << "#" << ref->size_ << "S"
         << "\t" << ref->ref_count_ << "C"
+        << "\t" << ref->type_ << "t"
         << "\n";
   }
   ofs << "kernel_def exc_order:\n";
@@ -241,7 +243,7 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph
 void MemReuseChecker::ExportMemOpIr(const KernelDef *def, std::ofstream &ofs, int def_idx) {
   auto scope_name = def->scope_full_name();
   std::string split_name = GetSplitName(scope_name);
-  ofs << "$" << def_idx << "\t" << split_name << "\t";
+  ofs << "$" << def_idx << "\t" << split_name << "\t" << static_cast<int>(def->type_) << "\t";
   ofs << "inputs[";
   for (auto &in : def->inputs_) {
     for (auto &in_ref : in.second) {
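
With the extra column, a tensor line in the dumped IR reads, for example, "%7T  #1024S  2C  3t": tensor index 7, 1024 bytes, ref count 2, type value 3. The type integer follows the enum order introduced above: COMMON = 0, REFNODE_OUTPUT = 1, COMM_NOTREUSE = 2, COMM_REUSE = 3, SUMMARY = 4, so "3t" marks a COMM_REUSE tensor.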

@@ -95,6 +95,12 @@ uint8_t *AscendMemoryManager::MallocStaticMem(size_t size, bool communication_me
   } else {
     align_size = GetCommonAlignSize(size);
   }
+  auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
+  MS_LOG(INFO) << "Malloc Memory: Static, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
+               << "] memory pool[" << device_mem_pool_offset << "])"
+               << " malloc [" << align_size << "]";
   if (communication_mem) {
     // create protect area [kMemAlignSize -- data -- kMemAlignSize]
     uint8_t *alloc_address = reinterpret_cast<uint8_t *>(AscendMemoryPool::GetInstance().AllocTensorMem(align_size));
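
A sketch of the protect area named in the comment, assuming align_size already includes both borders (cf. AlignCommunicationMemorySize above) and that the caller is handed alloc_address + kMemAlignSize:

  alloc_address           alloc_address + kMemAlignSize         alloc_address + align_size
  |-- left border (512) --|------------- data -----------------|-- right border (512) --|

Stray out-of-bounds accesses by a collective kernel then land in the borders rather than in a neighboring tensor.
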
@@ -111,12 +117,17 @@ uint8_t *AscendMemoryManager::MallocDynamicMem(size_t size, bool communication_m
   } else {
     align_size = GetCommonAlignSize(size);
   }
+  auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
+  MS_LOG(INFO) << "Malloc Memory: Dynamic, total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
+               << "] memory pool[" << device_mem_pool_offset << "])"
+               << " malloc [" << align_size << "]";
   if (dynamic_mem_offset_ < align_size) {
     MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
                       << "]) malloc [" << align_size << "] failed!";
   }
   auto new_offset = dynamic_mem_offset_ - align_size;
-  auto device_mem_pool_offset = AscendMemoryPool::GetInstance().device_mem_pool_offset();
   if (new_offset <= device_mem_pool_offset) {
     MS_LOG(EXCEPTION) << "Out of memory!!! total[" << device_mem_size_ << "] (dynamic[" << total_dynamic_size_
                       << "] memory pool[" << device_mem_pool_offset << "])"

@@ -398,7 +398,7 @@ void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph *graph) {
 }
 void KernelRuntime::AssignCommunicationNodeMem(int flag, const AnfNodePtr &node) {
-  AssignCommunicationNodeInputMem(node);
+  AssignCommunicationNodeInputMem(flag, node);
   AssignCommunicationNodeOutputMem(flag, node);
 }
@@ -428,6 +428,11 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr
     total_size += mem_size;
     align_size_list.emplace_back(mem_size);
   }
+  if (flag == kReuseDynamicMem) {
+    // reuse the memory of all the communication op's outputs
+    flag = kReuseDynamicCommMem;
+  }
   uint8_t *output_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
   for (size_t j = 0; j < align_size_list.size(); ++j) {
     std::string output_format = AnfAlgo::GetOutputFormat(node, j);
@@ -456,7 +461,7 @@ DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node,
   return address;
 }
-void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
+void KernelRuntime::AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node) {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
   MS_EXCEPTION_IF_NULL(node);
@@ -477,7 +482,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(const AnfNodePtr &node) {
     total_size += mem_size;
     addr_size.emplace_back(address.get(), mem_size);
   }
-  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, kDynamicMem, total_size);
+  uint8_t *input_ptr = mem_manager_->MallocOutputMem(node, 0, flag, total_size);
   for (const auto &iter : addr_size) {
     MS_EXCEPTION_IF_NULL(iter.first);
     iter.first->set_ptr(input_ptr);
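
Net effect of the hunks above: under the kReuseDynamicMem strategy, a communication node's outputs are switched to kReuseDynamicCommMem, so MallocOutputMem (see the MemoryManager hunk below) hands back the address planned by BestFitMemReuse instead of calling MallocDynamicMem; the inputs now follow the caller's flag as well, while still being packed into a single contiguous block.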

@@ -88,7 +88,7 @@ class KernelRuntime {
   void UpdateRefNodeOutputMem(const session::KernelGraph *graph);
   void AssignCommunicationNodeOutputMem(int flag, const AnfNodePtr &node);
-  void AssignCommunicationNodeInputMem(const AnfNodePtr &node);
+  void AssignCommunicationNodeInputMem(int flag, const AnfNodePtr &node);
   void AssignCommunicationNodeMem(int flag, const AnfNodePtr &node);
 #ifdef ENABLE_DUMP_E2E
   bool SetDumpConf();

@@ -57,6 +57,9 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, in
   }
   if (flag == kStaticMem) {
     ptr = MallocStaticMem(size, communication_mem);
+  } else if (flag == kReuseDynamicCommMem) {
+    MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr_);
+    ptr = mem_reuse_util_ptr_->GetNodeOutputPtr(node, index);
   } else {
     ptr = MallocDynamicMem(size, communication_mem);
   }
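
As far as this hunk shows, output allocation now routes as:

  flag == kStaticMem            ->  MallocStaticMem(size, communication_mem)
  flag == kReuseDynamicCommMem  ->  mem_reuse_util_ptr_->GetNodeOutputPtr(node, index)
                                    (the address pre-assigned by the BestFitMemReuse plan)
  other flags                   ->  MallocDynamicMem(size, communication_mem)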

@@ -25,6 +25,7 @@ namespace device {
 const int kStaticMem = 0;
 const int kDynamicMem = 1;
 const int kReuseDynamicMem = 2;
+const int kReuseDynamicCommMem = 3;
 const int kGetAllOuts = -1;
 const uint64_t kMemAlignSize = 512;
 using MemReuseUtilPtr = mindspore::memreuse::MemReuseUtilPtr;

@@ -146,7 +146,7 @@ TEST_F(TestMemReuseAllocator, mem_reuse_allocator_split_membuf) {
 TEST_F(TestMemReuseAllocator, mem_reuse_allocator_align) {
   auto best_fit_mem_reuse = std::make_shared<BestFitMemReuse>();
-  auto size = best_fit_mem_reuse->AlignMemorySize(510);
+  auto size = best_fit_mem_reuse->AlignCommonMemorySize(510);
   ASSERT_EQ(size, 1024);
 }
 }  // namespace memreuse

@@ -225,7 +225,6 @@ TEST_F(TestMemReuseWithPy, KernelRef) {
   ASSERT_EQ(kernel_ref_count_ptr->size_, 512);
   KernelDefPtr kernel_def_ptr = std::make_shared<KernelDef>();
   ASSERT_NE(kernel_def_ptr, nullptr);
-  ASSERT_EQ(kernel_def_ptr->dirty, false);
   MembufPtr membuf_ptr = std::make_shared<Membuf>();
   ASSERT_NE(membuf_ptr, nullptr);
 }
