!11457 PS cache insert linear

From: @gaoyong10
Reviewed-by: 
Signed-off-by:
pull/11457/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 16944856f9

@ -19,36 +19,66 @@
namespace mindspore {
namespace ps {
// NOTE(review): this span is a rendered diff in which removed and added lines are shown together, so it does not
// compile as shown (the parameter-list continuation appears twice and the braces do not balance). Confirm the
// final text against the repository before relying on it.
//
// Obtains a slot index for `id`, records the id -> slot mapping in hash_id_to_index_, and stamps the slot with
// `id` and `data_step`.
//   id                 - key to place in the table.
//   swap_out_index     - output array; receives the slot index when an occupied slot is reused.
//   swap_out_ids       - output array; receives the id that previously occupied the reused slot.
//   data_step          - step value written to the chosen slot via set_step().
//   graph_running_step - forwarded to FindInsertionPos.
//   swap_out_size      - in/out: write position in the two swap-out arrays; incremented after each write.
//   need_wait_graph    - forwarded to FindInsertionPos, which may set it to true.
// Returns the chosen slot index, or INVALID_INDEX_VALUE when FindInsertionPos finds no slot.
int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
const size_t graph_running_step, size_t *swap_out_size) {
const size_t graph_running_step, size_t *swap_out_size, bool *need_wait_graph) {
MS_EXCEPTION_IF_NULL(swap_out_index);
MS_EXCEPTION_IF_NULL(swap_out_ids);
MS_EXCEPTION_IF_NULL(swap_out_size);
// NOTE(review): the Hash()/NeedSwap()/probe-loop lines that follow look like the pre-change implementation that
// the FindInsertionPos call below replaces — confirm.
auto hash_index = Hash(id);
auto need_swap = NeedSwap();
size_t loop = 0;
while (true) {
if (loop++ == hash_capacity_) {
return INVALID_INDEX_VALUE;
}
if (hash_map_elements_[hash_index].IsEmpty()) {
// Ask FindInsertionPos for a slot; need_swap comes back true when the returned slot already holds an id.
bool need_swap = false;
auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph);
if (hash_index == INVALID_INDEX_VALUE) {
return hash_index;
}
// No swap needed: count the new entry and register id -> slot.
if (!need_swap) {
hash_count_++;
(void)hash_id_to_index_.emplace(id, hash_index);
hash_map_elements_[hash_index].set_id(id);
hash_map_elements_[hash_index].set_step(data_step);
return hash_index;
}
// The slot is being reused: report the evicted slot/id to the caller, drop the old mapping, then install the
// new id. hash_count_ is left unchanged on this path.
swap_out_index[*swap_out_size] = hash_index;
swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_;
(*swap_out_size)++;
(void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_);
(void)hash_id_to_index_.emplace(id, hash_index);
hash_map_elements_[hash_index].set_id(id);
hash_map_elements_[hash_index].set_step(data_step);
return hash_index;
}
// NOTE(review): rendered diff — removed and added lines are shown together. The statements below that use `id`,
// `swap_out_index`, `swap_out_ids` and `swap_out_size` reference names that are not parameters of this function;
// they look like removed lines of the old ParseData body. The `hash_index = (hash_index + 1) % hash_capacity_`
// line also looks like part of the old probe loop. Confirm against the repository.
//
// Walks the slots starting at current_pos_ (wrapping at hash_capacity_) and picks one to hand out:
//   - an empty slot is returned as-is;
//   - a slot for which IsExpired(graph_running_step) holds is returned with *need_swap set to true;
//   - a slot for which IsStep(graph_running_step) holds is not returned now; its index is appended to
//     graph_running_index_ for later use.
// Once the walk gets back to current_batch_start_pos_, expired_element_full_ is set and the walk stops. After
// that, the remembered graph_running_index_ entries are handed out one per call, with both *need_swap and
// *need_wait_graph set to true.
//   data_step          - referenced here only by the set_step() calls in the lines flagged above.
//   graph_running_step - step value compared against each slot via IsExpired()/IsStep().
//   need_swap          - out: set to true when the returned slot currently holds an id.
//   need_wait_graph    - out: set to true when the returned slot came from graph_running_index_.
// Returns the chosen slot index, or INVALID_INDEX_VALUE when nothing is available.
int EmbeddingHashMap::FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *need_swap,
bool *need_wait_graph) {
MS_EXCEPTION_IF_NULL(need_swap);
MS_EXCEPTION_IF_NULL(need_wait_graph);
int hash_index = INVALID_INDEX_VALUE;
// Keep scanning until this batch's walk has wrapped back to where it started.
while (!expired_element_full_) {
if (hash_map_elements_[current_pos_].IsEmpty()) {
hash_index = current_pos_;
hash_count_++;
(void)hash_id_to_index_.emplace(id, hash_index);
hash_map_elements_[hash_index].set_id(id);
hash_map_elements_[hash_index].set_step(data_step);
return hash_index;
} else if (need_swap && hash_map_elements_[hash_index].IsExpired(graph_running_step)) {
// Need swap out from the hash table.
swap_out_index[*swap_out_size] = hash_index;
swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_;
(*swap_out_size)++;
(void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_);
(void)hash_id_to_index_.emplace(id, hash_index);
hash_map_elements_[hash_index].set_id(id);
hash_map_elements_[hash_index].set_step(data_step);
} else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) {
// Occupied but expired: hand it out and tell the caller the old entry must be swapped out.
hash_index = current_pos_;
*need_swap = true;
} else if (hash_map_elements_[current_pos_].IsStep(graph_running_step)) {
// Slot's step equals graph_running_step: remember it as a fallback candidate instead of returning it now.
graph_running_index_[graph_running_index_num_++] = current_pos_;
}
// Advance the cursor before returning so the next call starts at the following slot.
current_pos_ = (current_pos_ + 1) % hash_capacity_;
if (hash_index != INVALID_INDEX_VALUE) {
return hash_index;
}
hash_index = (hash_index + 1) % hash_capacity_;
// Wrapped around to the position recorded by Reset(): stop scanning for the rest of this batch.
if (current_pos_ == current_batch_start_pos_) {
expired_element_full_ = true;
MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_
<< ") will be used, index swap will wait until the graph completed.";
}
}
// Scan exhausted: fall back to the remembered slots, which require the caller to wait for the graph.
if (graph_running_index_pos_ != graph_running_index_num_) {
*need_swap = true;
*need_wait_graph = true;
return graph_running_index_[graph_running_index_pos_++];
}
return INVALID_INDEX_VALUE;
}
void EmbeddingHashMap::DumpHashMap() {
@ -66,5 +96,12 @@ void EmbeddingHashMap::DumpHashMap() {
}
MS_LOG(INFO) << "Dump hash map info end.";
}
// Prepares the scan state for a new batch: clears the "scan wrapped" flag, empties the list of remembered
// graph-running slots, and records the current cursor as the position at which the next scan stops.
void EmbeddingHashMap::Reset() {
  expired_element_full_ = false;
  graph_running_index_pos_ = 0;
  graph_running_index_num_ = 0;
  current_batch_start_pos_ = current_pos_;
}
} // namespace ps
} // namespace mindspore

@ -34,6 +34,7 @@ struct HashMapElement {
size_t step_{INVALID_STEP_VALUE};
bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; }
bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; }
bool IsStep(size_t step) const { return step_ == step; }
void set_id(int id) { id_ = id; }
void set_step(size_t step) { step_ = step; }
};
@ -41,25 +42,39 @@ struct HashMapElement {
// Hash table is held in device, HashMap is used to manage hash table in host.
// NOTE(review): rendered diff — the one-line constructor header and the six-argument ParseData declaration below
// appear next to their replacements, so this class body does not compile as shown. Confirm against the repository.
//
// Tracks which id occupies which slot of a fixed-capacity table and hands out slots for new ids via ParseData().
class EmbeddingHashMap {
public:
EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) {
// Sizes both the slot vector and the graph_running_index_ buffer to hash_capacity; all scan state starts at
// zero/false.
EmbeddingHashMap(size_t hash_count, size_t hash_capacity)
: hash_count_(hash_count),
hash_capacity_(hash_capacity),
current_pos_(0),
current_batch_start_pos_(0),
graph_running_index_num_(0),
graph_running_index_pos_(0),
expired_element_full_(false) {
hash_map_elements_.resize(hash_capacity);
graph_running_index_ = std::make_unique<int[]>(hash_capacity);
}
virtual ~EmbeddingHashMap() = default;
// Chooses a slot for `id`; swap-out information is appended to the caller's arrays at *swap_out_size.
int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step,
const size_t graph_running_step, size_t *swap_out_size);
const size_t graph_running_step, size_t *swap_out_size, bool *need_wait_graph);
// Step accessors for a slot; hash_index is used directly as a vector subscript, with no range check.
size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; }
void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); }
// Read-only view of the id -> slot index map.
const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; }
size_t hash_capacity() const { return hash_capacity_; }
void DumpHashMap();
// Resets the per-batch scan state used by FindInsertionPos().
void Reset();
private:
// Scales the fractional part of 0.6180339 * id by hash_capacity_ and truncates to int.
int Hash(const int id) { return static_cast<int>((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); }
// True once hash_count_ exceeds hash_capacity_ * 0.9 (as converted by FloatToSize).
bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); }
// Picks the next slot to hand out; reports through the two bool out-parameters.
int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *need_swap, bool *need_wait_graph);
// Incremented when ParseData fills a slot without swapping; initial value comes from the constructor.
size_t hash_count_;
// Number of slots; hash_map_elements_ and graph_running_index_ are both sized to this.
size_t hash_capacity_;
std::vector<HashMapElement> hash_map_elements_;
std::unordered_map<int, int> hash_id_to_index_;
// Next slot FindInsertionPos() will examine; advances modulo hash_capacity_.
size_t current_pos_;
// Value of current_pos_ captured by Reset(); the scan stops when current_pos_ returns to it.
size_t current_batch_start_pos_;
// Number of entries stored in graph_running_index_.
size_t graph_running_index_num_;
// Index of the next graph_running_index_ entry to hand out.
size_t graph_running_index_pos_;
// Slot indices whose step equalled graph_running_step when the scan passed over them.
std::unique_ptr<int[]> graph_running_index_;
// Set once the scan has wrapped back to current_batch_start_pos_; cleared by Reset().
bool expired_element_full_;
};
} // namespace ps
} // namespace mindspore

@ -369,6 +369,10 @@ bool PsCacheManager::ProcessData() {
// Get hash swap in/out index and ids.
RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get()));
DumpStatisticsInfo();
if ((device_need_wait_graph_ || host_need_wait_graph_) && (!WaitGraphRun())) {
MS_LOG(ERROR) << "Ps cache wait graph finish failed.";
return false;
}
for (const auto &item : hash_tables_) {
auto key = worker.GetParamKey(item.first);
auto hash_info = item.second;
@ -454,6 +458,20 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id
return true;
}
// Resets the scan state of both the device and the host embedding hash maps and clears the two
// "need to wait for the graph" flags. All four null checks run before anything is modified.
// Returns true after the reset has been applied.
bool PsCacheManager::ResetEmbeddingHashMap() {
  // Device side: the cache object first, then the hash map it holds.
  MS_ERROR_IF_NULL(embedding_device_cache_);
  const auto &device_hash_map = embedding_device_cache_->device_hash_map_;
  MS_ERROR_IF_NULL(device_hash_map);

  // Host side: same two-level check.
  MS_ERROR_IF_NULL(embedding_host_cache_);
  const auto &host_hash_map = embedding_host_cache_->host_hash_map_;
  MS_ERROR_IF_NULL(host_hash_map);

  // Everything is valid; clear the wait flags and restart both maps' scans.
  host_need_wait_graph_ = false;
  device_need_wait_graph_ = false;
  device_hash_map->Reset();
  host_hash_map->Reset();
  return true;
}
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) {
MS_ERROR_IF_NULL(batch_ids);
MS_ERROR_IF_NULL(hash_index);
@ -463,6 +481,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
MS_LOG(EXCEPTION) << "Data in device memset failed.";
}
CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get());
RETURN_IF_FALSE(ResetEmbeddingHashMap());
for (size_t i = 0; i < batch_ids_len; i++) {
if (in_device[i]) {
continue;
@ -529,7 +548,7 @@ bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host,
auto tmp_device_to_host_size = statistics_info_.device_to_host_size_;
while (true) {
index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_,
&(statistics_info_.device_to_host_size_));
&(statistics_info_.device_to_host_size_), &device_need_wait_graph_);
if (index == INVALID_INDEX_VALUE) {
if (!WaitGraphRun()) {
return false;
@ -570,8 +589,9 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
MS_ERROR_IF_NULL(server_to_host_index);
MS_ERROR_IF_NULL(server_to_host_ids);
while (true) {
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
graph_running_step_, &statistics_info_.host_to_server_size_);
auto index =
host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_,
&statistics_info_.host_to_server_size_, &host_need_wait_graph_);
if (index == INVALID_INDEX_VALUE) {
RETURN_IF_FALSE(WaitGraphRun());
continue;
@ -607,8 +627,9 @@ bool PsCacheManager::ParseHostDataDeviceToHost() {
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
while (true) {
auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids,
data_step_, graph_running_step_, &statistics_info_.host_to_server_size_);
auto index =
host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step_,
graph_running_step_, &statistics_info_.host_to_server_size_, &host_need_wait_graph_);
if (index == INVALID_INDEX_VALUE) {
RETURN_IF_FALSE(WaitGraphRun());
continue;

@ -173,6 +173,7 @@ class PsCacheManager {
bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device,
size_t *hash_hit_count);
bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device);
bool ResetEmbeddingHashMap();
bool initialized_ps_cache_{false};
std::string channel_name_;
std::mutex channel_mutex_;
@ -198,6 +199,8 @@ class PsCacheManager {
std::atomic_bool finish_init_parameter_server_{false};
std::atomic_bool running_{false};
bool finish_embedding_table_sync_{false};
bool device_need_wait_graph_{false};
bool host_need_wait_graph_{false};
};
// File-scope reference bound to whatever PsCacheManager::GetInstance() returns.
// NOTE(review): this is declared `static`; if the line lives in a header, each including translation unit gets
// its own reference variable (all presumably bound to the same object) — confirm that is intended.
static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();

Loading…
Cancel
Save