|
|
|
@ -24,7 +24,15 @@ namespace device {
|
|
|
|
|
namespace memswap {
|
|
|
|
|
void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
execution_order_ = kernel_graph->execution_order();
|
|
|
|
|
graph_manager_ = kernel_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph_manager_);
|
|
|
|
|
auto &kernels = kernel_graph->execution_order();
|
|
|
|
|
for (const auto &kernel : kernels) {
|
|
|
|
|
if (AnfAlgo::IsRealCNodeKernel(kernel) && (!opt::IsNopNode(kernel))) {
|
|
|
|
|
execution_order_.push_back(kernel);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t kernel_index = 0;
|
|
|
|
|
for (const auto &kernel : execution_order_) {
|
|
|
|
|
// parse topo order of kernel
|
|
|
|
@ -41,7 +49,7 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// parse topo order of user kernel
|
|
|
|
|
SaveUserKernelTopoOrder(kernel_graph);
|
|
|
|
|
SaveUserKernelTopoOrder();
|
|
|
|
|
|
|
|
|
|
sort(ordered_tensors_.begin(), ordered_tensors_.end(),
|
|
|
|
|
[](const TensorInfo &a, const TensorInfo &b) { return a.tensor_size_ > b.tensor_size_; });
|
|
|
|
@ -62,11 +70,22 @@ void MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
mem_copy_manager_->Init();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGraph *kernel_graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
FuncGraphManagerPtr manager = kernel_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
NodeUsersMap user_map = manager->node_users();
|
|
|
|
|
bool MemSwapManager::IsCommunicationRelevantOp(const AnfNodePtr &kernel) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel);
|
|
|
|
|
NodeUsersMap &user_map = graph_manager_->node_users();
|
|
|
|
|
auto iter = user_map.find(kernel);
|
|
|
|
|
bool adjacent_with_communication_op = false;
|
|
|
|
|
if (iter != user_map.end()) {
|
|
|
|
|
AnfNodeIndexSet node_set = iter->second;
|
|
|
|
|
adjacent_with_communication_op = std::any_of(
|
|
|
|
|
node_set.begin(), node_set.end(),
|
|
|
|
|
[](const std::pair<AnfNodePtr, int> &node_pair) { return AnfAlgo::IsCommunicationOp(node_pair.first); });
|
|
|
|
|
}
|
|
|
|
|
return (AnfAlgo::IsCommunicationOp(kernel)) || adjacent_with_communication_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MemSwapManager::SaveUserKernelTopoOrder() {
|
|
|
|
|
NodeUsersMap &user_map = graph_manager_->node_users();
|
|
|
|
|
for (const auto &kernel : execution_order_) {
|
|
|
|
|
auto iter = user_map.find(kernel);
|
|
|
|
|
if (iter == user_map.end()) {
|
|
|
|
@ -76,13 +95,16 @@ void MemSwapManager::SaveUserKernelTopoOrder(const mindspore::session::KernelGra
|
|
|
|
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
|
|
|
|
for (auto &node_pair : node_set) {
|
|
|
|
|
auto user_kernel = node_pair.first;
|
|
|
|
|
if (!AnfAlgo::IsRealCNodeKernel(user_kernel)) {
|
|
|
|
|
if (!AnfAlgo::IsRealCNodeKernel(user_kernel) || opt::IsNopNode(user_kernel)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t user_kernel_topo_sort = SearchKernelExecutionInfo(user_kernel).topo_order_;
|
|
|
|
|
auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(user_kernel, node_pair.second - 1);
|
|
|
|
|
auto &output_idx = kernel_with_index.second;
|
|
|
|
|
if (kernel_with_index.first.get() != kernel.get()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Save user kernel topo order failed for op[" << AnfAlgo::GetCNodeName(kernel) << "]";
|
|
|
|
|
}
|
|
|
|
|
kernel_exec_info.node_users_map_[output_idx].push_back(user_kernel_topo_sort);
|
|
|
|
|
}
|
|
|
|
|
for (auto &node_user_pair : kernel_exec_info.node_users_map_) {
|
|
|
|
@ -100,6 +122,9 @@ void MemSwapManager::AddSwapInfo() {
|
|
|
|
|
|
|
|
|
|
size_t output_idx = tensor.output_idx_;
|
|
|
|
|
const AnfNodePtr &kernel = tensor.kernel_;
|
|
|
|
|
if (IsCommunicationRelevantOp(kernel)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto &kernel_exec_info = SearchKernelExecutionInfo(kernel);
|
|
|
|
|
auto &node_users_map = kernel_exec_info.node_users_map_;
|
|
|
|
|
|
|
|
|
@ -178,7 +203,7 @@ bool MemSwapManager::RetreatSwapInfo() {
|
|
|
|
|
|
|
|
|
|
while (tensor_size_threshold_idx_ < ordered_tensors_.size() - 1) {
|
|
|
|
|
++tensor_size_threshold_idx_;
|
|
|
|
|
if (tensor_size_threshold_idx_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) {
|
|
|
|
|
if (tensor_size_threshold_ > ordered_tensors_[tensor_size_threshold_idx_].tensor_size_) {
|
|
|
|
|
tensor_size_threshold_ = ordered_tensors_[tensor_size_threshold_idx_].tensor_size_;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|