@@ -42,9 +42,21 @@ int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
  return compress_count;
}

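// Shape-refresh helper: re-infers every output dtype of the node and shrinks the inferred shape
// of outputs 1-3 to the actual number of cache misses, while output 0 keeps its original shape.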
void UpdateShape(size_t miss_count, const CNodePtr &node_) {
  std::vector<size_t> out_shape;
  out_shape.emplace_back(miss_count);
  std::vector<TypeId> dtypes;
  size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
  for (size_t i = 0; i < output_num; i++) {
    dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
  }
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},
                                      node_.get());
}

void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  node_ = kernel_node;
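  // Also keep a weak reference to the kernel node; LaunchKernel locks node_wpt_ before use.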
  node_wpt_ = kernel_node;
  auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  if (hashmap_shape.size() != 2) {
    MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)";
@@ -73,6 +85,7 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
template <typename T>
void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                        const std::vector<kernel::AddressPtr> &outputs) {
  auto node_ = node_wpt_.lock();
  auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
  batch_size_ = 1;
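  // The loop below accumulates the element count of the embedding-index input into batch_size_
  // (its body lies outside this hunk; presumably it multiplies the dimensions together).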
  for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
@@ -92,7 +105,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
  float total_count = 0;
  int count_size = 0;
  float hit_count = 0;

  // search_cache_idx
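  // For each input index, probe the hashmap from its hash slot: a hit yields the cached position,
  // a miss is recorded so the swap phase below can bring the key into the cache.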
  for (size_t i = 0; i < batch_size_; ++i) {
    T key = input_indices[i] - offset;
@@ -107,7 +119,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
      tmp_entry = (tmp_entry + 1) % hashmap_length_;
      if (count > hashmap_length_) {
        MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!";
        break;
      }
      count += 1;
    }
@@ -130,7 +141,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
    MS_LOG(INFO) << "Avg search count: " << total_count / count_size;
    MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size;
  }

  float total_insert_count = 0;
  float total_delete_count = 0;
  // swap hash map
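  // For every missed key, probe for a free hashmap slot to insert it, then probe again to delete
  // an old key so its cache position can be handed to the new one.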
@@ -142,7 +152,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
      entry = (entry + 1) % hashmap_length_;
      if (tag_count > hashmap_length_) {
        MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!";
        break;
      }
      tag_count++;
    }
@@ -155,7 +164,6 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
      tmp_entry = (tmp_entry + 1) % hashmap_length_;
      if (delete_count > hashmap_length_) {
        MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!";
        break;
      }
      delete_count++;
    }
@@ -171,22 +179,11 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
    MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count;
    MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count;
  }
  // update step
  step_[0] += 1;
  // update cache idx
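  // Patch the positions that missed during the search phase with the cache slots assigned above.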
  for (size_t i = 0; i < miss_count; ++i) {
    int idx = miss_idx[i];
    output_cache_idx[idx] = output_swap_cache_idx[i];
    output_cache_idx[miss_idx[i]] = output_swap_cache_idx[i];
  }
  std::vector<size_t> out_shape;
  out_shape.emplace_back(miss_count);
  std::vector<TypeId> dtypes;
  size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
  for (size_t i = 0; i < output_num; i++) {
    dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
  }
  AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},
                                      node_.get());
  UpdateShape(miss_count, node_);
}
} // namespace kernel
} // namespace mindspore