|
|
|
@ -32,7 +32,11 @@ namespace memswap {
|
|
|
|
|
class MemSwapManager {
|
|
|
|
|
public:
|
|
|
|
|
explicit MemSwapManager(const MemCopyManagerPtr &mem_copy_manager)
|
|
|
|
|
: tensor_size_threshold_(0), tensor_size_threshold_idx_(0), tensor_size_num_(1), distance_threshold_(1) {
|
|
|
|
|
: tensor_size_threshold_(0),
|
|
|
|
|
tensor_size_threshold_idx_(0),
|
|
|
|
|
tensor_size_num_(1),
|
|
|
|
|
distance_threshold_(1),
|
|
|
|
|
distance_decay_step_(1) {
|
|
|
|
|
mem_copy_manager_ = mem_copy_manager;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -42,7 +46,7 @@ class MemSwapManager {
|
|
|
|
|
|
|
|
|
|
~MemSwapManager() = default;
|
|
|
|
|
|
|
|
|
|
void Init(const mindspore::session::KernelGraph *kernel_graph);
|
|
|
|
|
bool Init(const mindspore::session::KernelGraph *kernel_graph, size_t swap_mem_size = 0);
|
|
|
|
|
|
|
|
|
|
void AddMemSwapTask(SwapKind swap_kind, const DeviceAddressPtr &device_address,
|
|
|
|
|
const HostAddress &host_address) const;
|
|
|
|
@ -51,9 +55,10 @@ class MemSwapManager {
|
|
|
|
|
|
|
|
|
|
DeviceAddressPtr UpdateSwapQueue(SwapKind swap_kind) const;
|
|
|
|
|
|
|
|
|
|
// retreat to find a workable swap scheme
|
|
|
|
|
bool RetreatSwapInfo();
|
|
|
|
|
|
|
|
|
|
void AdjustSwapInPos(const AnfNodePtr &kernel, size_t index);
|
|
|
|
|
|
|
|
|
|
bool trigger_swap() const { return trigger_swap_; }
|
|
|
|
|
|
|
|
|
|
bool mem_swap_init() const { return mem_swap_initialized_; }
|
|
|
|
@ -70,16 +75,28 @@ class MemSwapManager {
|
|
|
|
|
|
|
|
|
|
bool QueryKernelTriggerSwap(const AnfNodePtr &kernel) const;
|
|
|
|
|
|
|
|
|
|
bool QueryKernelNeedSwap(const AnfNodePtr &kernel) const;
|
|
|
|
|
bool QueryKernelTriggerSwapIn(const AnfNodePtr &kernel) const;
|
|
|
|
|
|
|
|
|
|
size_t QueryKernelTriggerSwapInTaskNum(const AnfNodePtr &kernel) const;
|
|
|
|
|
|
|
|
|
|
const AnfNodePtr QueryKerneByTopoOrder(size_t index) const;
|
|
|
|
|
|
|
|
|
|
const MemSwapInfoSet &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const;
|
|
|
|
|
|
|
|
|
|
void AssignHostMemory();
|
|
|
|
|
|
|
|
|
|
const std::vector<MemSwapInfo> &QueryKernelMemSwapInfo(const AnfNodePtr &kernel) const;
|
|
|
|
|
const HostAddress &QueryKernelHostAddr(const AnfNodePtr &kernel, size_t output_idx) const;
|
|
|
|
|
|
|
|
|
|
void AddKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx, bool dirty);
|
|
|
|
|
|
|
|
|
|
bool QueryKernelHostAddrIsDirty(const AnfNodePtr &kernel, size_t output_idx) const;
|
|
|
|
|
|
|
|
|
|
void ResetHostAddrIsDirty();
|
|
|
|
|
|
|
|
|
|
void InsertSwapInBlackList(const void *device_ptr);
|
|
|
|
|
|
|
|
|
|
bool FindInSwapInBlackList(const void *device_ptr) const;
|
|
|
|
|
|
|
|
|
|
const HostAddress &kernel_host_addr(const AnfNodePtr &kernel, size_t output_idx) const;
|
|
|
|
|
|
|
|
|
|
bool AllocHostPinnedMem(size_t size, void **addr) const;
|
|
|
|
|
|
|
|
|
|
void ReleaseHostPinnedMem();
|
|
|
|
@ -93,27 +110,47 @@ class MemSwapManager {
|
|
|
|
|
|
|
|
|
|
void SaveUserKernelTopoOrder();
|
|
|
|
|
|
|
|
|
|
void AddKernelTriggerSwap(const AnfNodePtr &kernel, bool trigger_swap);
|
|
|
|
|
bool InitSwapThreshold(size_t swap_mem_size);
|
|
|
|
|
|
|
|
|
|
void AddKernelNeedSwap(const AnfNodePtr &kernel, bool need_swap);
|
|
|
|
|
void RetreatSwapThreshold();
|
|
|
|
|
|
|
|
|
|
void CacheCurSwapInfoSet(const AnfNodePtr &kernel);
|
|
|
|
|
|
|
|
|
|
void AddFirstTimeMovePos(const AnfNodePtr &kernel, size_t index, bool first_time);
|
|
|
|
|
|
|
|
|
|
bool QueryFirstTimeMovePos(const AnfNodePtr &kernel, size_t index) const;
|
|
|
|
|
|
|
|
|
|
size_t BestSwapInPerformPos(const AnfNodePtr &trigger_kernel, const MemSwapInfo &mem_swap_info) const;
|
|
|
|
|
|
|
|
|
|
void MoveSwapInfoPos(size_t des_pos, size_t src_pos, const MemSwapInfo &mem_swap_info);
|
|
|
|
|
|
|
|
|
|
void AddKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info);
|
|
|
|
|
|
|
|
|
|
void RemoveKernelMemSwapInfo(const AnfNodePtr &kernel, const MemSwapInfo &mem_swap_info);
|
|
|
|
|
|
|
|
|
|
bool CheckDistanceBetweenKernels(const TensorInfo &tensor_info) const;
|
|
|
|
|
|
|
|
|
|
bool IsCommunicationRelevantOp(const AnfNodePtr &kernel) const;
|
|
|
|
|
|
|
|
|
|
std::vector<CNodePtr> execution_order_;
|
|
|
|
|
std::vector<TensorInfo> ordered_tensors_;
|
|
|
|
|
std::unordered_map<void *, KernelExecutionInfo> kernel_execution_info_;
|
|
|
|
|
std::unordered_map<void *, std::map<size_t, PerformPair>> kernel_swap_perform_;
|
|
|
|
|
// trigger swap kernel key : MemSwapInfo of kernel need to be swapped
|
|
|
|
|
std::unordered_map<void *, std::vector<MemSwapInfo>> mem_swap_info_;
|
|
|
|
|
// Key: trigger swap kernel, value: MemSwapInfoSet of kernel need to be swapped
|
|
|
|
|
std::unordered_map<void *, MemSwapInfoSet> mem_swap_info_map_;
|
|
|
|
|
std::vector<HostAddress> host_addrs_list_;
|
|
|
|
|
std::unordered_set<const void *> swap_in_blacklist_;
|
|
|
|
|
|
|
|
|
|
// Key: cache kernel address, value: lists of first time move pos or not
|
|
|
|
|
std::map<void *, std::vector<bool>> kernel_first_move_cache_map_;
|
|
|
|
|
std::vector<MemSwapInfo> mem_swap_info_cache_list_;
|
|
|
|
|
std::pair<size_t, size_t> best_and_cur_pos_cache_;
|
|
|
|
|
|
|
|
|
|
size_t tensor_size_threshold_;
|
|
|
|
|
size_t tensor_size_threshold_idx_;
|
|
|
|
|
size_t tensor_size_num_;
|
|
|
|
|
size_t distance_threshold_;
|
|
|
|
|
size_t distance_decay_step_;
|
|
|
|
|
|
|
|
|
|
MemCopyManagerPtr mem_copy_manager_{nullptr};
|
|
|
|
|
FuncGraphManagerPtr graph_manager_{nullptr};
|
|
|
|
|