|  |  |  | @ -30,11 +30,11 @@ void BestFitMemReuse::InitMemReuseInfo(const MemReuseUtil *mem_reuse_util_ptr) { | 
			
		
	
		
			
				
					|  |  |  |  |   set_op_ptr_list(mem_reuse_util_ptr->kernel_def_ptr_list()); | 
			
		
	
		
			
				
					|  |  |  |  |   // check info Correctness
 | 
			
		
	
		
			
				
					|  |  |  |  |   for (auto &tensor : tensor_ptr_list_) { | 
			
		
	
		
			
				
					|  |  |  |  |     tensor->size_ = AlignMemorySize(tensor->size_); | 
			
		
	
		
			
				
					|  |  |  |  |     tensor->size_ = AlignCommonMemorySize(tensor->size_); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  |   // align wk size to 512 && refcount == 1
 | 
			
		
	
		
			
				
					|  |  |  |  |   for (auto &wk : wk_tensor_list_) { | 
			
		
	
		
			
				
					|  |  |  |  |     wk->size_ = AlignMemorySize(wk->size_); | 
			
		
	
		
			
				
					|  |  |  |  |     wk->size_ = AlignCommonMemorySize(wk->size_); | 
			
		
	
		
			
				
					|  |  |  |  |     wk->ref_count_ = 1; | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | #ifdef ENABLE_D | 
			
		
	
	
		
			
				
					|  |  |  | @ -123,11 +123,23 @@ bool BestFitMemReuse::IsUsable(const KernelDefPtr &kernel_curr, const MembufPtr | 
			
		
	
		
			
				
					|  |  |  |  |   return false; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | void BestFitMemReuse::AssignNodeOutputOffset() { | 
			
		
	
		
			
				
					|  |  |  |  | void BestFitMemReuse::AssignCommonNodeOutputOffset() { | 
			
		
	
		
			
				
					|  |  |  |  |   MS_EXCEPTION_IF_NULL(current_kernel_); | 
			
		
	
		
			
				
					|  |  |  |  |   for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { | 
			
		
	
		
			
				
					|  |  |  |  |     size_t index = GetTensorIndex(tensor_idx); | 
			
		
	
		
			
				
					|  |  |  |  |     auto tensor_desc = tensor_ptr_list_[index]; | 
			
		
	
		
			
				
					|  |  |  |  |     MS_EXCEPTION_IF_NULL(tensor_desc); | 
			
		
	
		
			
				
					|  |  |  |  |     if (tensor_desc->type_ == REFNODE_OUTPUT) { | 
			
		
	
		
			
				
					|  |  |  |  |       total_refoutput_size += tensor_desc->size_; | 
			
		
	
		
			
				
					|  |  |  |  |       continue; | 
			
		
	
		
			
				
					|  |  |  |  |     } else if (tensor_desc->type_ == COMM_NOTREUSE) { | 
			
		
	
		
			
				
					|  |  |  |  |       total_comm_not_reuse_size += tensor_desc->size_; | 
			
		
	
		
			
				
					|  |  |  |  |     } else if (tensor_desc->type_ == COMM_REUSE) { | 
			
		
	
		
			
				
					|  |  |  |  |       // get align size for communication op's single input
 | 
			
		
	
		
			
				
					|  |  |  |  |       tensor_desc->size_ = AlignCommunicationMemorySize(tensor_desc->size_); | 
			
		
	
		
			
				
					|  |  |  |  |       total_comm_reuse_size += tensor_desc->size_; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     auto reusable_membuf_map = GetReusableMembufMap(tensor_desc->size_); | 
			
		
	
		
			
				
					|  |  |  |  |     if (!reusable_membuf_map.empty()) { | 
			
		
	
		
			
				
					|  |  |  |  |       auto membuf_index = reusable_membuf_map.begin()->second; | 
			
		
	
	
		
			
				
					|  |  |  | @ -140,6 +152,86 @@ void BestFitMemReuse::AssignNodeOutputOffset() { | 
			
		
	
		
			
				
					|  |  |  |  |       MemReuseChecker::GetInstance().IsAddNewMembuf_ = true; | 
			
		
	
		
			
				
					|  |  |  |  | #endif | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     // skip left align border for communication op single input to used
 | 
			
		
	
		
			
				
					|  |  |  |  |     if (tensor_desc->type_ == COMM_REUSE) { | 
			
		
	
		
			
				
					|  |  |  |  |       tensor_desc->offset_ += kDefaultMemAlignSize; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | void BestFitMemReuse::AssignCommunicationNodeOutputOffset() { | 
			
		
	
		
			
				
					|  |  |  |  |   size_t total_kernel_output_size = 0; | 
			
		
	
		
			
				
					|  |  |  |  |   size_t output_num = 0; | 
			
		
	
		
			
				
					|  |  |  |  |   // get all output size
 | 
			
		
	
		
			
				
					|  |  |  |  |   MS_EXCEPTION_IF_NULL(current_kernel_); | 
			
		
	
		
			
				
					|  |  |  |  |   for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { | 
			
		
	
		
			
				
					|  |  |  |  |     size_t index = GetTensorIndex(tensor_idx); | 
			
		
	
		
			
				
					|  |  |  |  |     auto tensor_desc = tensor_ptr_list_[index]; | 
			
		
	
		
			
				
					|  |  |  |  |     MS_EXCEPTION_IF_NULL(tensor_desc); | 
			
		
	
		
			
				
					|  |  |  |  |     if (tensor_desc->type_ == COMM_REUSE) { | 
			
		
	
		
			
				
					|  |  |  |  |       total_comm_reuse_size += tensor_desc->size_; | 
			
		
	
		
			
				
					|  |  |  |  |       total_comm_output_reuse_size += tensor_desc->size_; | 
			
		
	
		
			
				
					|  |  |  |  |       total_kernel_output_size += tensor_desc->size_; | 
			
		
	
		
			
				
					|  |  |  |  |     } else { | 
			
		
	
		
			
				
					|  |  |  |  |       MS_LOG(ERROR) << "All communication op's outputs should be memory reuse, Kernel:" | 
			
		
	
		
			
				
					|  |  |  |  |                     << current_kernel_->scope_full_name(); | 
			
		
	
		
			
				
					|  |  |  |  |       continue; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  |   total_kernel_output_size = AlignCommunicationMemorySize(total_kernel_output_size); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   // add left align border for the first output and right align border for the last output to alloc align border memory
 | 
			
		
	
		
			
				
					|  |  |  |  |   size_t output_index = 0; | 
			
		
	
		
			
				
					|  |  |  |  |   for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { | 
			
		
	
		
			
				
					|  |  |  |  |     size_t index = GetTensorIndex(tensor_idx); | 
			
		
	
		
			
				
					|  |  |  |  |     auto tensor_desc = tensor_ptr_list_[index]; | 
			
		
	
		
			
				
					|  |  |  |  |     MS_EXCEPTION_IF_NULL(tensor_desc); | 
			
		
	
		
			
				
					|  |  |  |  |     if (output_index == 0 || output_index == output_num - 1) { | 
			
		
	
		
			
				
					|  |  |  |  |       tensor_desc->size_ += kDefaultMemAlignSize; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     output_index++; | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |   auto reusable_membuf_map = GetReusableMembufMap(total_kernel_output_size); | 
			
		
	
		
			
				
					|  |  |  |  |   if (!reusable_membuf_map.empty()) { | 
			
		
	
		
			
				
					|  |  |  |  |     auto membuf_index = reusable_membuf_map.begin()->second; | 
			
		
	
		
			
				
					|  |  |  |  |     output_index = 0; | 
			
		
	
		
			
				
					|  |  |  |  |     for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { | 
			
		
	
		
			
				
					|  |  |  |  |       size_t index = GetTensorIndex(tensor_idx); | 
			
		
	
		
			
				
					|  |  |  |  |       auto tensor_desc = tensor_ptr_list_[index]; | 
			
		
	
		
			
				
					|  |  |  |  |       MS_EXCEPTION_IF_NULL(tensor_desc); | 
			
		
	
		
			
				
					|  |  |  |  |       ReuseExistMembuf(tensor_desc.get(), membuf_index + output_index, kDynamicMem); | 
			
		
	
		
			
				
					|  |  |  |  |       // skip skip left align border for communication op's first output to used
 | 
			
		
	
		
			
				
					|  |  |  |  |       if (output_index == 0) { | 
			
		
	
		
			
				
					|  |  |  |  |         tensor_desc->offset_ += kDefaultMemAlignSize; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |       output_index++; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } else { | 
			
		
	
		
			
				
					|  |  |  |  |     // no membuf can reuse, add new membuf after the membuf_ptr_list
 | 
			
		
	
		
			
				
					|  |  |  |  |     output_index = 0; | 
			
		
	
		
			
				
					|  |  |  |  |     for (auto &tensor_idx : current_kernel_->GetOutputRefIndexs()) { | 
			
		
	
		
			
				
					|  |  |  |  |       size_t index = GetTensorIndex(tensor_idx); | 
			
		
	
		
			
				
					|  |  |  |  |       auto tensor_desc = tensor_ptr_list_[index]; | 
			
		
	
		
			
				
					|  |  |  |  |       MS_EXCEPTION_IF_NULL(tensor_desc); | 
			
		
	
		
			
				
					|  |  |  |  |       AddNewMembufPtr(tensor_desc.get(), kDynamicMem); | 
			
		
	
		
			
				
					|  |  |  |  |       // skip align size offset for first output to used
 | 
			
		
	
		
			
				
					|  |  |  |  |       if (output_index == 0) { | 
			
		
	
		
			
				
					|  |  |  |  |         tensor_desc->offset_ += kDefaultMemAlignSize; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
		
			
				
					|  |  |  |  |       output_index++; | 
			
		
	
		
			
				
					|  |  |  |  | #ifdef MEM_REUSE_DEBUG | 
			
		
	
		
			
				
					|  |  |  |  |       MemReuseChecker::GetInstance().IsAddNewMembuf_ = true; | 
			
		
	
		
			
				
					|  |  |  |  | #endif | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | void BestFitMemReuse::AssignNodeOutputOffset() { | 
			
		
	
		
			
				
					|  |  |  |  |   if (current_kernel_->type_ == COMMUNICATION_NODE) { | 
			
		
	
		
			
				
					|  |  |  |  |     AssignCommunicationNodeOutputOffset(); | 
			
		
	
		
			
				
					|  |  |  |  |   } else { | 
			
		
	
		
			
				
					|  |  |  |  |     AssignCommonNodeOutputOffset(); | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -307,11 +399,17 @@ void BestFitMemReuse::ReleaseMembuf(size_t tensor_index, int flag) { | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | size_t BestFitMemReuse::AlignMemorySize(size_t size) const { | 
			
		
	
		
			
				
					|  |  |  |  | size_t BestFitMemReuse::AlignCommonMemorySize(size_t size) const { | 
			
		
	
		
			
				
					|  |  |  |  |   // memory size 512 align
 | 
			
		
	
		
			
				
					|  |  |  |  |   return (size + kDefaultMemAlignSize + kAttAlignSize) / kDefaultMemAlignSize * kDefaultMemAlignSize; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | size_t BestFitMemReuse::AlignCommunicationMemorySize(size_t size) const { | 
			
		
	
		
			
				
					|  |  |  |  |   // memory size 512 align and add communication memory:  left align border memory - data - right align border memory
 | 
			
		
	
		
			
				
					|  |  |  |  |   return kDefaultMemAlignSize + (size + kDefaultMemAlignSize - 1) / kDefaultMemAlignSize * kDefaultMemAlignSize + | 
			
		
	
		
			
				
					|  |  |  |  |          kDefaultMemAlignSize; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | size_t BestFitMemReuse::GetAllocatedSize() { | 
			
		
	
		
			
				
					|  |  |  |  |   size_t AllocatedSize = kTotalSize; | 
			
		
	
		
			
				
					|  |  |  |  |   if (membuf_ptr_list_.empty()) { | 
			
		
	
	
		
			
				
					|  |  |  | @ -400,6 +498,9 @@ void BestFitMemReuse::Reuse(const MemReuseUtil *mem_reuse_util_ptr) { | 
			
		
	
		
			
				
					|  |  |  |  |     ++op_num; | 
			
		
	
		
			
				
					|  |  |  |  | #endif | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  |   MS_LOG(INFO) << "Special Tensor total size: RefOutput: " << total_refoutput_size | 
			
		
	
		
			
				
					|  |  |  |  |                << " CommReuse: " << total_comm_reuse_size << " CommOutputReuse: " << total_comm_output_reuse_size | 
			
		
	
		
			
				
					|  |  |  |  |                << " CommNotReuse: " << total_comm_not_reuse_size; | 
			
		
	
		
			
				
					|  |  |  |  | #ifdef MEM_REUSE_DEBUG | 
			
		
	
		
			
				
					|  |  |  |  |   MemReuseChecker::GetInstance().ExportMembufInfoIR(); | 
			
		
	
		
			
				
					|  |  |  |  |   MemReuseChecker::GetInstance().ExportAddNewMmebufIR(); | 
			
		
	
	
		
			
				
					|  |  |  | 
 |