add ps cache

pull/9675/head
limingqi107 4 years ago
parent a5f57ce8a0
commit 660a087ffd

@ -24,6 +24,7 @@ if (NOT ENABLE_GPU)
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/gpu/gpu_ps_cache.cc")
endif()
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_prefetch.cc")
list(REMOVE_ITEM _PS_SRC_FILES "ps_cache/ps_data/ps_data_channel.cc")
add_subdirectory(ps_cache)

@ -0,0 +1,69 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/ps_cache/embedding_hash_map.h"
namespace mindspore {
namespace ps {
int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
const size_t graph_running_step, size_t *swap_out_size) {
MS_EXCEPTION_IF_NULL(swap_out_index);
MS_EXCEPTION_IF_NULL(swap_out_ids);
MS_EXCEPTION_IF_NULL(swap_out_size);
auto hash_index = Hash(id);
auto need_swap = NeedSwap();
size_t loop = 0;
while (true) {
if (loop++ == hash_capacity_) {
return INVALID_INDEX_VALUE;
}
if (hash_map_unit_[hash_index].IsEmpty()) {
hash_count_++;
(void)hash_id_to_index_.emplace(id, hash_index);
hash_map_unit_[hash_index].set_id(id);
hash_map_unit_[hash_index].set_step(data_step);
return hash_index;
} else if (need_swap && hash_map_unit_[hash_index].IsExpired(graph_running_step)) {
// Need swap out from the hash table.
swap_out_index[*swap_out_size] = hash_index;
swap_out_ids[*swap_out_size] = hash_map_unit_[hash_index].id_;
(*swap_out_size)++;
(void)hash_id_to_index_.erase(hash_map_unit_[hash_index].id_);
(void)hash_id_to_index_.emplace(id, hash_index);
hash_map_unit_[hash_index].set_id(id);
hash_map_unit_[hash_index].set_step(data_step);
return hash_index;
}
hash_index = (hash_index + 1) % hash_capacity_;
}
}
void EmbeddingHashMap::DumpHashMap() {
MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_;
MS_LOG(INFO) << "Dump hash_id_to_index: ";
for (auto iter = hash_id_to_index_.begin(); iter != hash_id_to_index_.end(); ++iter) {
MS_LOG(INFO) << " id: " << iter->first << " index: " << iter->second;
}
MS_LOG(INFO) << "Dump hash_map_unit: ";
for (size_t i = 0; i < hash_map_unit_.size(); i++) {
if (!hash_map_unit_[i].IsEmpty()) {
MS_LOG(INFO) << " index: " << i << " id: " << hash_map_unit_[i].id_ << " step: " << hash_map_unit_[i].step_;
}
}
MS_LOG(INFO) << "Dump hash map info end.";
}
} // namespace ps
} // namespace mindspore

@ -0,0 +1,68 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_
#include <math.h>
#include <utility>
#include <memory>
#include <vector>
#include <unordered_map>
#include "utils/convert_utils_base.h"
namespace mindspore {
namespace ps {
static const size_t INVALID_STEP_VALUE = 0;
static const int INVALID_INDEX_VALUE = -1;
struct HashMapElement {
int id_;
size_t step_;
bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; }
bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }
void set_id(int id) { id_ = id; }
void set_step(size_t step) { step_ = step; }
};
// Hash table is held in device, HashMap is used to manage hash table in host.
class EmbeddingHashMap {
public:
EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) {
hash_map_unit_.resize(hash_capacity);
}
virtual ~EmbeddingHashMap() = default;
int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
const size_t graph_running_step, size_t *swap_out_size);
std::unordered_map<int, int>::const_iterator id_iter(const int id) const { return hash_id_to_index_.find(id); }
bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const {
return iter != hash_id_to_index_.end();
}
size_t hash_step(const int hash_index) const { return hash_map_unit_[hash_index].step_; }
void set_hash_step(const int hash_index, const size_t step) { hash_map_unit_[hash_index].set_step(step); }
void DumpHashMap();
private:
int Hash(const int id) { return static_cast<int>((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); }
bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); }
size_t hash_count_;
size_t hash_capacity_;
std::vector<HashMapElement> hash_map_unit_;
std::unordered_map<int, int> hash_id_to_index_;
};
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_

@ -0,0 +1,110 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace ps {
void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
if (cache_enable_ == false) {
return;
}
MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ").";
auto iter = ps_data_channel_map_.find(channel_name);
if (iter != ps_data_channel_map_.end()) {
MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name;
auto channel = iter->second;
MS_EXCEPTION_IF_NULL(channel);
channel->set_step_num(step_num);
} else {
auto channel = std::make_shared<PsDataChannel>(channel_name, step_num);
MS_EXCEPTION_IF_NULL(channel);
(void)ps_data_channel_map_.emplace(channel_name, channel);
}
}
std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const {
auto iter = ps_data_channel_map_.find(channel_name);
if (iter == ps_data_channel_map_.end()) {
MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name;
}
return iter->second;
}
void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
if (cache_enable_ == false) {
return;
}
if (data == nullptr) {
MS_LOG(WARNING) << "No data prefetch.";
return;
}
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
channel->set_data(data, data_size);
std::unique_lock<std::mutex> locker(data_mutex_);
data_ready_ = true;
data_process_.notify_one();
for (int i = 0; i < 10; i++) {
if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) {
return;
} else {
MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)";
}
}
MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size.";
}
void PsDataPrefetch::FinalizeData(const std::string &channel_name) {
if (cache_enable_ == false) {
return;
}
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
channel->ResetData();
std::unique_lock<std::mutex> locker(data_mutex_);
data_ready_ = false;
data_prefetch_.notify_one();
for (int i = 0; i < 10; i++) {
if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) {
return;
} else {
MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)";
}
}
MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout.";
}
void *PsDataPrefetch::data(const std::string &channel_name) const {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
return channel->data();
}
size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
return channel->data_size();
}
void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) {
auto channel = ps_data_channel(channel_name);
MS_EXCEPTION_IF_NULL(channel);
channel->TryWakeChannel();
}
} // namespace ps
} // namespace mindspore

@ -0,0 +1,60 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_
#include <map>
#include <string>
#include <memory>
#include <condition_variable>
#include "ps/ps_cache/ps_data/ps_data_channel.h"
#define EXPORT __attribute__((visibility("default")))
namespace mindspore {
namespace ps {
class PsDataPrefetch {
public:
EXPORT static PsDataPrefetch &GetInstance() {
static PsDataPrefetch instance;
return instance;
}
EXPORT bool cache_enable() const { return cache_enable_; }
EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num);
EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
EXPORT void FinalizeData(const std::string &channel_name);
EXPORT void *data(const std::string &channel_name) const;
EXPORT size_t data_size(const std::string &channel_name) const;
EXPORT void TryWakeChannel(const std::string &channel_name);
private:
PsDataPrefetch() : cache_enable_(false), data_ready_(false) {}
virtual ~PsDataPrefetch() = default;
PsDataPrefetch(const PsDataPrefetch &) = delete;
PsDataPrefetch &operator=(const PsDataPrefetch &) = delete;
std::shared_ptr<PsDataChannel> ps_data_channel(const std::string &channel_name) const;
std::map<std::string, std::shared_ptr<PsDataChannel>> ps_data_channel_map_;
bool cache_enable_;
bool data_ready_;
std::mutex data_mutex_;
std::condition_variable data_prefetch_;
std::condition_variable data_process_;
};
} // namespace ps
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_
Loading…
Cancel
Save