diff --git a/mindspore/ccsrc/ps/CMakeLists.txt b/mindspore/ccsrc/ps/CMakeLists.txt index a2b292c226..5fafca3f1f 100644 --- a/mindspore/ccsrc/ps/CMakeLists.txt +++ b/mindspore/ccsrc/ps/CMakeLists.txt @@ -24,6 +24,7 @@ if (NOT ENABLE_GPU) list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc") endif() +list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc") list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc") add_subdirectory(ps_cache) diff --git a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc new file mode 100755 index 0000000000..00c020ef95 --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/ps_cache/embedding_hash_map.h" + +namespace mindspore { +namespace ps { +int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, + const size_t graph_running_step, size_t *swap_out_size) { + MS_EXCEPTION_IF_NULL(swap_out_index); + MS_EXCEPTION_IF_NULL(swap_out_ids); + MS_EXCEPTION_IF_NULL(swap_out_size); + auto hash_index = Hash(id); + auto need_swap = NeedSwap(); + size_t loop = 0; + while (true) { + if (loop++ == hash_capacity_) { + return INVALID_INDEX_VALUE; + } + if (hash_map_unit_[hash_index].IsEmpty()) { + hash_count_++; + (void)hash_id_to_index_.emplace(id, hash_index); + hash_map_unit_[hash_index].set_id(id); + hash_map_unit_[hash_index].set_step(data_step); + return hash_index; + } else if (need_swap && hash_map_unit_[hash_index].IsExpired(graph_running_step)) { + // Need swap out from the hash table. + swap_out_index[*swap_out_size] = hash_index; + swap_out_ids[*swap_out_size] = hash_map_unit_[hash_index].id_; + (*swap_out_size)++; + (void)hash_id_to_index_.erase(hash_map_unit_[hash_index].id_); + (void)hash_id_to_index_.emplace(id, hash_index); + hash_map_unit_[hash_index].set_id(id); + hash_map_unit_[hash_index].set_step(data_step); + return hash_index; + } + hash_index = (hash_index + 1) % hash_capacity_; + } +} + +void EmbeddingHashMap::DumpHashMap() { + MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_; + MS_LOG(INFO) << "Dump hash_id_to_index: "; + for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) { + MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second; + } + MS_LOG(INFO) << "Dump hash_map_unit: "; + for (size_t i = 0; i < hash_map_unit_.size(); i++) { + if (!hash_map_unit_[i].IsEmpty()) { + MS_LOG(INFO) << " index: " << i << " id: " << hash_map_unit_[i].id_ << " step: " << hash_map_unit_[i].step_; + } + } + MS_LOG(INFO) << "Dump hash map info end."; +} +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h new file mode 100644 index 0000000000..8df58c0927 --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ +#define MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ + +#include +#include +#include +#include +#include +#include "utils/convert_utils_base.h" + +namespace mindspore { +namespace ps { +static const size_t INVALID_STEP_VALUE = 0; +static const int INVALID_INDEX_VALUE = -1; + +struct HashMapElement { + int id_; + size_t step_; + bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } + bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } + void set_id(int id) { id_ = id; } + void set_step(size_t step) { step_ = step; } +}; + +// Hash table is held in device, HashMap is used to manage hash table in host. +class EmbeddingHashMap { + public: + EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { + hash_map_unit_.resize(hash_capacity); + } + virtual ~EmbeddingHashMap() = default; + int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, + const size_t graph_running_step, size_t *swap_out_size); + std::unordered_map::const_iterator id_iter(const int id) const { return hash_id_to_index_.find(id); } + bool IsIdExist(const std::unordered_map::const_iterator iter) const { + return iter != hash_id_to_index_.end(); + } + size_t hash_step(const int hash_index) const { return hash_map_unit_[hash_index].step_; } + void set_hash_step(const int hash_index, const size_t step) { hash_map_unit_[hash_index].set_step(step); } + void DumpHashMap(); + + private: + int Hash(const int id) { return static_cast((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); } + bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } + size_t hash_count_; + size_t hash_capacity_; + std::vector hash_map_unit_; + std::unordered_map hash_id_to_index_; +}; +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_ diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc new file mode 100644 index 0000000000..33cb35c44d --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace ps { +void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) { + if (cache_enable_ == false) { + return; + } + MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ")."; + auto iter = ps_data_channel_map_.find(channel_name); + if (iter != ps_data_channel_map_.end()) { + MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name; + auto channel = iter->second; + MS_EXCEPTION_IF_NULL(channel); + channel->set_step_num(step_num); + } else { + auto channel = std::make_shared(channel_name, step_num); + MS_EXCEPTION_IF_NULL(channel); + (void)ps_data_channel_map_.emplace(channel_name, channel); + } +} + +std::shared_ptr PsDataPrefetch::ps_data_channel(const std::string &channel_name) const { + auto iter = ps_data_channel_map_.find(channel_name); + if (iter == ps_data_channel_map_.end()) { + MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name; + } + return iter->second; +} + +void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { + if (cache_enable_ == false) { + return; + } + if (data == nullptr) { + MS_LOG(WARNING) << "No data prefetch."; + return; + } + auto channel = ps_data_channel(channel_name); + MS_EXCEPTION_IF_NULL(channel); + channel->set_data(data, data_size); + std::unique_lock locker(data_mutex_); + data_ready_ = true; + data_process_.notify_one(); + for (int i = 0; i < 10; i++) { + if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) { + return; + } else { + MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)"; + } + } + MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size."; +} + +void PsDataPrefetch::FinalizeData(const std::string &channel_name) { + if (cache_enable_ == false) { + return; + } + auto channel = ps_data_channel(channel_name); + MS_EXCEPTION_IF_NULL(channel); + channel->ResetData(); + std::unique_lock locker(data_mutex_); + data_ready_ = false; + data_prefetch_.notify_one(); + for (int i = 0; i < 10; i++) { + if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) { + return; + } else { + MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)"; + } + } + MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout."; +} + +void *PsDataPrefetch::data(const std::string &channel_name) const { + auto channel = ps_data_channel(channel_name); + MS_EXCEPTION_IF_NULL(channel); + return channel->data(); +} + +size_t PsDataPrefetch::data_size(const std::string &channel_name) const { + auto channel = ps_data_channel(channel_name); + MS_EXCEPTION_IF_NULL(channel); + return channel->data_size(); +} + +void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { + auto channel = ps_data_channel(channel_name); + MS_EXCEPTION_IF_NULL(channel); + channel->TryWakeChannel(); +} +} // namespace ps +} // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h new file mode 100644 index 0000000000..f4c00bfb90 --- /dev/null +++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_ +#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_ + +#include +#include +#include +#include +#include "ps/ps_cache/ps_data/ps_data_channel.h" + +#define EXPORT __attribute__((visibility("default"))) + +namespace mindspore { +namespace ps { +class PsDataPrefetch { + public: + EXPORT static PsDataPrefetch &GetInstance() { + static PsDataPrefetch instance; + return instance; + } + + EXPORT bool cache_enable() const { return cache_enable_; } + EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } + EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); + EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size); + EXPORT void FinalizeData(const std::string &channel_name); + EXPORT void *data(const std::string &channel_name) const; + EXPORT size_t data_size(const std::string &channel_name) const; + EXPORT void TryWakeChannel(const std::string &channel_name); + + private: + PsDataPrefetch() : cache_enable_(false), data_ready_(false) {} + virtual ~PsDataPrefetch() = default; + PsDataPrefetch(const PsDataPrefetch &) = delete; + PsDataPrefetch &operator=(const PsDataPrefetch &) = delete; + std::shared_ptr ps_data_channel(const std::string &channel_name) const; + std::map> ps_data_channel_map_; + bool cache_enable_; + bool data_ready_; + std::mutex data_mutex_; + std::condition_variable data_prefetch_; + std::condition_variable data_process_; +}; +} // namespace ps +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_