|
|
|
@ -25,10 +25,7 @@ namespace memswap {
|
|
|
|
|
void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
execution_order_ = kernel_graph->execution_order();
|
|
|
|
|
FuncGraphManagerPtr manager = kernel_graph->manager();
|
|
|
|
|
NodeUsersMap user_map = manager->node_users();
|
|
|
|
|
size_t kernel_index = 0;
|
|
|
|
|
|
|
|
|
|
for (const auto &kernel : execution_order_) {
|
|
|
|
|
// parse topo order of kernel
|
|
|
|
|
kernel_execution_info_.emplace(kernel.get(), kernel_index++);
|
|
|
|
@ -44,6 +41,31 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// parse topo order of user kernel
|
|
|
|
|
SaveUserKernelTopoOrder(kernel_graph);
|
|
|
|
|
|
|
|
|
|
sort(ordered_tensors_.begin(), ordered_tensors_.end(),
|
|
|
|
|
[](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; });
|
|
|
|
|
|
|
|
|
|
auto cur_tensor_size = ordered_tensors_.front().tensor_size_;
|
|
|
|
|
for (auto &tensor_info : ordered_tensors_) {
|
|
|
|
|
if (cur_tensor_size != tensor_info.tensor_size_) {
|
|
|
|
|
cur_tensor_size = tensor_info.tensor_size_;
|
|
|
|
|
tensor_size_num_++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
tensor_size_threshold_ = ordered_tensors_.front().tensor_size_;
|
|
|
|
|
tensor_size_threshold_idx_ = 0;
|
|
|
|
|
|
|
|
|
|
distance_threshold_ = kernel_index / kDistanceInitFactor;
|
|
|
|
|
mem_swap_initialized_ = true;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mem_copy_manager_);
|
|
|
|
|
mem_copy_manager_->Init();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
FuncGraphManagerPtr manager = kernel_graph->manager();
|
|
|
|
|
NodeUsersMap user_map = manager->node_users();
|
|
|
|
|
for (const auto &kernel : execution_order_) {
|
|
|
|
|
auto iter = user_map.find(kernel);
|
|
|
|
|
if (iter == user_map.end()) {
|
|
|
|
@ -66,24 +88,6 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
sort(node_user_pair.second.begin(), node_user_pair.second.end());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sort(ordered_tensors_.begin(), ordered_tensors_.end(),
|
|
|
|
|
[](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; });
|
|
|
|
|
|
|
|
|
|
auto cur_tensor_size = ordered_tensors_.front().tensor_size_;
|
|
|
|
|
for (auto &tensor_info : ordered_tensors_) {
|
|
|
|
|
if (cur_tensor_size != tensor_info.tensor_size_) {
|
|
|
|
|
cur_tensor_size = tensor_info.tensor_size_;
|
|
|
|
|
tensor_size_num_++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
tensor_size_threshold_ = ordered_tensors_.front().tensor_size_;
|
|
|
|
|
tensor_size_threshold_idx_ = 0;
|
|
|
|
|
|
|
|
|
|
distance_threshold_ = kernel_index / kDistanceInitFactor;
|
|
|
|
|
mem_swap_initialized_ = true;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mem_copy_manager_);
|
|
|
|
|
mem_copy_manager_->Init();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MemSwapManager::AddSwapInfo() {
|
|
|
|
@ -228,12 +232,12 @@ float MemSwapManager::QueryKernelExecutionPerform(const AnfNodePtr &kernel) cons
|
|
|
|
|
return kernel_exec_info.execution_perform_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MemSwapManager::QueryKerneTriggerSwap(const AnfNodePtr &kernel) const {
|
|
|
|
|
bool MemSwapManager::QueryKernelTriggerSwap(const AnfNodePtr &kernel) const {
|
|
|
|
|
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
|
|
|
|
return kernel_exec_info.trigger_swap_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MemSwapManager::QueryKerneNeedSwap(const AnfNodePtr &kernel) const {
|
|
|
|
|
bool MemSwapManager::QueryKernelNeedSwap(const AnfNodePtr &kernel) const {
|
|
|
|
|
const auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
|
|
|
|
return kernel_exec_info.need_swap_;
|
|
|
|
|
}
|
|
|
|
@ -254,7 +258,7 @@ const PerformPair &MemSwapManager::QueryKernelSwapPerform(const AnfNodePtr &kern
|
|
|
|
|
return iter_output->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<MemSwapInfo> &MemSwapManager::QueryKerneMemSwapInfo(const AnfNodePtr &kernel) const {
|
|
|
|
|
const std::vector<MemSwapInfo> &MemSwapManager::QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel);
|
|
|
|
|
auto iter = mem_swap_info_.find(kernel.get());
|
|
|
|
|
if (iter == mem_swap_info_.end()) {
|
|
|
|
|