parent
a5f57ce8a0
commit
660a087ffd
@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/ps_cache/embedding_hash_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
|
||||
const size_t graph_running_step, size_t *swap_out_size) {
|
||||
MS_EXCEPTION_IF_NULL(swap_out_index);
|
||||
MS_EXCEPTION_IF_NULL(swap_out_ids);
|
||||
MS_EXCEPTION_IF_NULL(swap_out_size);
|
||||
auto hash_index = Hash(id);
|
||||
auto need_swap = NeedSwap();
|
||||
size_t loop = 0;
|
||||
while (true) {
|
||||
if (loop++ == hash_capacity_) {
|
||||
return INVALID_INDEX_VALUE;
|
||||
}
|
||||
if (hash_map_unit_[hash_index].IsEmpty()) {
|
||||
hash_count_++;
|
||||
(void)hash_id_to_index_.emplace(id, hash_index);
|
||||
hash_map_unit_[hash_index].set_id(id);
|
||||
hash_map_unit_[hash_index].set_step(data_step);
|
||||
return hash_index;
|
||||
} else if (need_swap && hash_map_unit_[hash_index].IsExpired(graph_running_step)) {
|
||||
// Need swap out from the hash table.
|
||||
swap_out_index[*swap_out_size] = hash_index;
|
||||
swap_out_ids[*swap_out_size] = hash_map_unit_[hash_index].id_;
|
||||
(*swap_out_size)++;
|
||||
(void)hash_id_to_index_.erase(hash_map_unit_[hash_index].id_);
|
||||
(void)hash_id_to_index_.emplace(id, hash_index);
|
||||
hash_map_unit_[hash_index].set_id(id);
|
||||
hash_map_unit_[hash_index].set_step(data_step);
|
||||
return hash_index;
|
||||
}
|
||||
hash_index = (hash_index + 1) % hash_capacity_;
|
||||
}
|
||||
}
|
||||
|
||||
// Writes the whole map state to the INFO log: capacity and count, every id -> index pair, and every
// non-empty slot with its id and step.
void EmbeddingHashMap::DumpHashMap() {
  MS_LOG(INFO) << "Dump hash map info begin, hash_capacity: " << hash_capacity_ << " hash_count: " << hash_count_;
  MS_LOG(INFO) << "Dump hash_id_to_index: ";
  for (const auto &id_and_index : hash_id_to_index_) {
    MS_LOG(INFO) << " id: " << id_and_index.first << " index: " << id_and_index.second;
  }
  MS_LOG(INFO) << "Dump hash_map_unit: ";
  const size_t unit_count = hash_map_unit_.size();
  for (size_t index = 0; index < unit_count; ++index) {
    const auto &unit = hash_map_unit_[index];
    if (unit.IsEmpty()) {
      // Unused slots carry no information worth logging.
      continue;
    }
    MS_LOG(INFO) << " index: " << index << " id: " << unit.id_ << " step: " << unit.step_;
  }
  MS_LOG(INFO) << "Dump hash map info end.";
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
@ -0,0 +1,68 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_
|
||||
|
||||
#include <math.h>
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include "utils/convert_utils_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
// Step value that marks a hash map slot as unused (see HashMapElement::IsEmpty).
static const size_t INVALID_STEP_VALUE = 0;
// Index returned by EmbeddingHashMap::ParseData when no slot could be claimed.
static const int INVALID_INDEX_VALUE = -1;

// One slot of the host-side hash map: the id stored in the slot and the step recorded for it.
struct HashMapElement {
  // Both members carry default initializers so that a default-constructed element is a well-defined empty
  // slot instead of holding indeterminate values. The values equal the zero state that value-initialization
  // (e.g. std::vector::resize) produces.
  int id_{0};
  size_t step_{INVALID_STEP_VALUE};

  // A slot is unused while its step still equals INVALID_STEP_VALUE.
  bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; }

  // A slot is expired when its recorded step is strictly older than `graph_running_step`.
  bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }

  void set_id(int id) { id_ = id; }
  void set_step(size_t step) { step_ = step; }
};
|
||||
|
||||
// Hash table is held in device, HashMap is used to manage hash table in host.
|
||||
class EmbeddingHashMap {
|
||||
public:
|
||||
EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) {
|
||||
hash_map_unit_.resize(hash_capacity);
|
||||
}
|
||||
virtual ~EmbeddingHashMap() = default;
|
||||
int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
|
||||
const size_t graph_running_step, size_t *swap_out_size);
|
||||
std::unordered_map<int, int>::const_iterator id_iter(const int id) const { return hash_id_to_index_.find(id); }
|
||||
bool IsIdExist(const std::unordered_map<int, int>::const_iterator iter) const {
|
||||
return iter != hash_id_to_index_.end();
|
||||
}
|
||||
size_t hash_step(const int hash_index) const { return hash_map_unit_[hash_index].step_; }
|
||||
void set_hash_step(const int hash_index, const size_t step) { hash_map_unit_[hash_index].set_step(step); }
|
||||
void DumpHashMap();
|
||||
|
||||
private:
|
||||
int Hash(const int id) { return static_cast<int>((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); }
|
||||
bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); }
|
||||
size_t hash_count_;
|
||||
size_t hash_capacity_;
|
||||
std::vector<HashMapElement> hash_map_unit_;
|
||||
std::unordered_map<int, int> hash_id_to_index_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_EMBEDDING_HASH_MAP_H_
|
@ -0,0 +1,110 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ps/ps_cache/ps_data/ps_data_prefetch.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t step_num) {
|
||||
if (cache_enable_ == false) {
|
||||
return;
|
||||
}
|
||||
MS_LOG(INFO) << "PS cache creates data channel(channel name:" << channel_name << ", step num:" << step_num << ").";
|
||||
auto iter = ps_data_channel_map_.find(channel_name);
|
||||
if (iter != ps_data_channel_map_.end()) {
|
||||
MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name;
|
||||
auto channel = iter->second;
|
||||
MS_EXCEPTION_IF_NULL(channel);
|
||||
channel->set_step_num(step_num);
|
||||
} else {
|
||||
auto channel = std::make_shared<PsDataChannel>(channel_name, step_num);
|
||||
MS_EXCEPTION_IF_NULL(channel);
|
||||
(void)ps_data_channel_map_.emplace(channel_name, channel);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<PsDataChannel> PsDataPrefetch::ps_data_channel(const std::string &channel_name) const {
|
||||
auto iter = ps_data_channel_map_.find(channel_name);
|
||||
if (iter == ps_data_channel_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name;
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) {
|
||||
if (cache_enable_ == false) {
|
||||
return;
|
||||
}
|
||||
if (data == nullptr) {
|
||||
MS_LOG(WARNING) << "No data prefetch.";
|
||||
return;
|
||||
}
|
||||
auto channel = ps_data_channel(channel_name);
|
||||
MS_EXCEPTION_IF_NULL(channel);
|
||||
channel->set_data(data, data_size);
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
data_ready_ = true;
|
||||
data_process_.notify_one();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) {
|
||||
return;
|
||||
} else {
|
||||
MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)";
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size.";
|
||||
}
|
||||
|
||||
void PsDataPrefetch::FinalizeData(const std::string &channel_name) {
|
||||
if (cache_enable_ == false) {
|
||||
return;
|
||||
}
|
||||
auto channel = ps_data_channel(channel_name);
|
||||
MS_EXCEPTION_IF_NULL(channel);
|
||||
channel->ResetData();
|
||||
std::unique_lock<std::mutex> locker(data_mutex_);
|
||||
data_ready_ = false;
|
||||
data_prefetch_.notify_one();
|
||||
for (int i = 0; i < 10; i++) {
|
||||
if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) {
|
||||
return;
|
||||
} else {
|
||||
MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)";
|
||||
}
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout.";
|
||||
}
|
||||
|
||||
void *PsDataPrefetch::data(const std::string &channel_name) const {
|
||||
auto channel = ps_data_channel(channel_name);
|
||||
MS_EXCEPTION_IF_NULL(channel);
|
||||
return channel->data();
|
||||
}
|
||||
|
||||
size_t PsDataPrefetch::data_size(const std::string &channel_name) const {
|
||||
auto channel = ps_data_channel(channel_name);
|
||||
MS_EXCEPTION_IF_NULL(channel);
|
||||
return channel->data_size();
|
||||
}
|
||||
|
||||
void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) {
|
||||
auto channel = ps_data_channel(channel_name);
|
||||
MS_EXCEPTION_IF_NULL(channel);
|
||||
channel->TryWakeChannel();
|
||||
}
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
@ -0,0 +1,60 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_
|
||||
#define MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <condition_variable>
|
||||
#include "ps/ps_cache/ps_data/ps_data_channel.h"
|
||||
|
||||
// Marks a declaration with default symbol visibility.
#define EXPORT __attribute__((visibility("default")))
|
||||
|
||||
namespace mindspore {
|
||||
namespace ps {
|
||||
class PsDataPrefetch {
|
||||
public:
|
||||
EXPORT static PsDataPrefetch &GetInstance() {
|
||||
static PsDataPrefetch instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
EXPORT bool cache_enable() const { return cache_enable_; }
|
||||
EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
|
||||
EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num);
|
||||
EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size);
|
||||
EXPORT void FinalizeData(const std::string &channel_name);
|
||||
EXPORT void *data(const std::string &channel_name) const;
|
||||
EXPORT size_t data_size(const std::string &channel_name) const;
|
||||
EXPORT void TryWakeChannel(const std::string &channel_name);
|
||||
|
||||
private:
|
||||
PsDataPrefetch() : cache_enable_(false), data_ready_(false) {}
|
||||
virtual ~PsDataPrefetch() = default;
|
||||
PsDataPrefetch(const PsDataPrefetch &) = delete;
|
||||
PsDataPrefetch &operator=(const PsDataPrefetch &) = delete;
|
||||
std::shared_ptr<PsDataChannel> ps_data_channel(const std::string &channel_name) const;
|
||||
std::map<std::string, std::shared_ptr<PsDataChannel>> ps_data_channel_map_;
|
||||
bool cache_enable_;
|
||||
bool data_ready_;
|
||||
std::mutex data_mutex_;
|
||||
std::condition_variable data_prefetch_;
|
||||
std::condition_variable data_process_;
|
||||
};
|
||||
} // namespace ps
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PS_PS_CACHE_PS_DATA_PS_DATA_PREFETCH_H_
|
Loading…
Reference in new issue