|
|
|
@ -81,11 +81,9 @@ void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
// Caches per-node metadata on the kernel object: the node handle (node_),
// the first dimension of the hashmap input's shape (hashmap_length_), and
// that input's data type (dtype_).
|
|
|
|
|
// Null guard on the incoming node (presumably raises when kernel_node is null;
// the macro is defined outside this view).
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
|
// Keep the node handle in a member for use outside this function.
node_ = kernel_node;
|
|
|
|
|
// Inferred shape for index 0 of this node (presumably the hashmap tensor,
// going by the variable name and the error message below).
auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
|
|
|
|
|
|
|
// Only a rank-2 shape is accepted; the message documents the expected layout
// as (n, 4), but only the rank is actually checked here.
if (hashmap_shape.size() != 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
// First dimension of the shape is stored as the hashmap length.
hashmap_length_ = hashmap_shape[0];
|
|
|
|
|
// Inferred data type for the same index 0; stored in dtype_.
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
|
|
|
|
|
}
|
|
|
|
@ -121,7 +119,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr);
|
|
|
|
|
auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr);
|
|
|
|
|
auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr);
|
|
|
|
|
|
|
|
|
|
std::vector<T> miss_idx;
|
|
|
|
|
size_t miss_count = 0;
|
|
|
|
|
float total_count = 0;
|
|
|
|
@ -134,9 +131,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
output_cache_idx[i] = -1;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T tmp_entry = HashFunc(key, hashmap_length_);
|
|
|
|
|
|
|
|
|
|
size_t count = 1;
|
|
|
|
|
count_size += 1;
|
|
|
|
|
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
|
|
|
|
@ -147,7 +142,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
}
|
|
|
|
|
count += 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
total_count += count;
|
|
|
|
|
if (hashmap[tmp_entry].IsEmpty()) {
|
|
|
|
|
miss_idx.emplace_back(i);
|
|
|
|
@ -163,10 +157,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
MS_LOG(INFO) << "Miss count: " << miss_count;
|
|
|
|
|
MS_LOG(INFO) << "Avg search count: " << total_count / count_size;
|
|
|
|
|
MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size;
|
|
|
|
|
|
|
|
|
|
float total_insert_count = 0;
|
|
|
|
|
float total_delete_count = 0;
|
|
|
|
|
|
|
|
|
|
// swap hash map
|
|
|
|
|
for (size_t i = 0; i < miss_count; ++i) {
|
|
|
|
|
T emb_idx = output_miss_emb_idx[i];
|
|
|
|
@ -180,11 +172,9 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
}
|
|
|
|
|
tag_count++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
hashmap[entry].key = emb_idx;
|
|
|
|
|
hashmap[entry].step = step_[0];
|
|
|
|
|
hashmap[entry].tag = tag_count;
|
|
|
|
|
|
|
|
|
|
T tmp_entry = (entry + 1) % hashmap_length_;
|
|
|
|
|
size_t delete_count = 1;
|
|
|
|
|
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
|
|
|
|
@ -195,7 +185,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
}
|
|
|
|
|
delete_count++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
|
|
|
|
|
output_old_emb_idx[i] = hashmap[tmp_entry].key;
|
|
|
|
|
hashmap[entry].value = output_swap_cache_idx[i];
|
|
|
|
@ -204,19 +193,15 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
|
|
|
|
total_delete_count += (compress_count + delete_count);
|
|
|
|
|
total_insert_count += tag_count;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count;
|
|
|
|
|
MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count;
|
|
|
|
|
|
|
|
|
|
// update step
|
|
|
|
|
step_[0] += 1;
|
|
|
|
|
|
|
|
|
|
// update cache idx
|
|
|
|
|
for (size_t i = 0; i < miss_count; ++i) {
|
|
|
|
|
int idx = miss_idx[i];
|
|
|
|
|
output_cache_idx[idx] = output_swap_cache_idx[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<size_t> out_shape;
|
|
|
|
|
out_shape.emplace_back(miss_count);
|
|
|
|
|
std::vector<TypeId> dtypes;
|
|
|
|
|