add dynamic ops

fangzehua 4 years ago
parent 8aa78c2c8e
commit b7d8e87647

@ -35,7 +35,6 @@ void AssignCPUKernel::InitKernel(const CNodePtr &kernel_node) {
input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) {
input_x_dtype_size_ = 4;
} else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) {
@ -75,6 +74,5 @@ void AssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
} // namespace kernel
} // namespace mindspore

@ -60,7 +60,6 @@ MS_REG_CPU_KERNEL(
} // namespace kernel
} // namespace mindspore

@ -20,7 +20,6 @@
namespace mindspore {
namespace kernel {
template <typename T>
void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
T i = (entry + 1) % length, off = 1;
@ -107,6 +106,5 @@ void CacheSwapHashmapCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inpu
} // namespace kernel
} // namespace mindspore

@ -25,7 +25,6 @@
namespace mindspore {
namespace kernel {
class CacheSwapHashmapCPUKernel : public CPUKernel {
CacheSwapHashmapCPUKernel() = default;
@ -82,7 +81,6 @@ MS_REG_CPU_KERNEL(CacheSwapHashmap,
} // namespace kernel
} // namespace mindspore

@ -22,7 +22,6 @@
namespace mindspore {
namespace kernel {
template <typename T>
struct HashmapEntry {
T key;
@ -60,8 +59,9 @@ T HashFunc(const T &key, const size_t &m) {
template <typename T>
void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
T i = (entry + 1) % length, off = 1;
int compress_count = 0;
for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) {
if (entry_p[i].tag > off) {
entry_p[entry].key = entry_p[i].key;
@ -72,21 +72,20 @@ void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
off = 0;
entry = i;
return compress_count;
void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
node_ = kernel_node;
auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
if (hashmap_shape.size() != 2) {
MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)";
for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
batch_size_ *= emb_idx_shape[i];
hashmap_length_ = hashmap_shape[0];
dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
@ -108,100 +107,124 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
template <typename T>
void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
batch_size_ = 1;
for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
batch_size_ *= emb_idx_shape[i];
HashmapEntry<T> *hashmap = reinterpret_cast<HashmapEntry<T> *>(inputs[0]->addr);
auto input_indices = reinterpret_cast<T *>(inputs[1]->addr);
T *step_ = reinterpret_cast<T *>(inputs[2]->addr);
T emb_max_num = *reinterpret_cast<T *>(inputs[3]->addr);
T cache_max_num = *reinterpret_cast<T *>(inputs[4]->addr);
T offset = *reinterpret_cast<T *>(inputs[4]->addr);
auto output_cache_idx = reinterpret_cast<T *>(outputs[0]->addr);
auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr);
auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr);
auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr);
std::vector<T> output_miss_idx(batch_size_, -1);
std::vector<T> miss_idx;
size_t miss_count = 0;
float total_count = 0;
int count_size = 0;
float hit_count = 0;
// search_cache_idx
for (size_t i = 0; i < batch_size_; ++i) {
if (input_indices[i] == emb_max_num) {
output_miss_idx[i] = -1;
output_cache_idx[i] = cache_max_num;
output_miss_emb_idx[i] = -1;
T key = input_indices[i] - offset;
if (key >= emb_max_num || key < 0) {
output_cache_idx[i] = -1;
T key = input_indices[i];
T tmp_entry = HashFunc(key, hashmap_length_);
int count = 1;
size_t count = 1;
count_size += 1;
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
if (count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!";
count += 1;
total_count += count;
if (hashmap[tmp_entry].IsEmpty()) {
output_miss_idx[i] = i;
output_miss_emb_idx[i] = key;
output_miss_emb_idx[miss_count] = key;
output_cache_idx[i] = -1;
} else {
hit_count += 1;
output_miss_idx[i] = -1;
output_cache_idx[i] = hashmap[tmp_entry].value;
hashmap[tmp_entry].step = step_[0];
output_miss_emb_idx[i] = -1;
MS_LOG(INFO) << "avg search count: " << total_count / count_size;
MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size;
MS_LOG(INFO) << "Miss count: " << miss_count;
MS_LOG(INFO) << "Avg search count: " << total_count / count_size;
MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size;
float total_insert_count = 0;
float total_delete_count = 0;
// swap hash map
for (size_t i = 0; i < batch_size_; ++i) {
if (output_miss_emb_idx[i] < 0) {
output_swap_cache_idx[i] = -1;
output_old_emb_idx[i] = -1;
} else {
T emb_idx = output_miss_emb_idx[i];
T entry = HashFunc(emb_idx, hashmap_length_);
T tag_count = 1;
while (!hashmap[entry].IsEmpty()) {
entry = (entry + 1) % hashmap_length_;
for (size_t i = 0; i < miss_count; ++i) {
T emb_idx = output_miss_emb_idx[i];
T entry = HashFunc(emb_idx, hashmap_length_);
size_t tag_count = 1;
while (!hashmap[entry].IsEmpty()) {
entry = (entry + 1) % hashmap_length_;
if (tag_count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, insert new key failed!";
hashmap[entry].key = emb_idx;
hashmap[entry].step = step_[0];
hashmap[entry].tag = tag_count;
T tmp_entry = (entry + 1) % hashmap_length_;
hashmap[entry].key = emb_idx;
hashmap[entry].step = step_[0];
hashmap[entry].tag = tag_count;
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
T tmp_entry = (entry + 1) % hashmap_length_;
size_t delete_count = 1;
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
if (delete_count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, delete old key failed!";
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
output_old_emb_idx[i] = hashmap[tmp_entry].key;
hashmap[entry].value = output_swap_cache_idx[i];
Compress(hashmap, hashmap_length_, tmp_entry);
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
output_old_emb_idx[i] = hashmap[tmp_entry].key;
hashmap[entry].value = output_swap_cache_idx[i];
int compress_count = Compress(hashmap, hashmap_length_, tmp_entry);
total_delete_count += (compress_count + delete_count);
total_insert_count += tag_count;
MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count;
MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count;
// update step
step_[0] += 1;
// update cache idx
for (size_t i = 0; i < batch_size_; ++i) {
if (output_miss_idx[i] < 0 || output_miss_idx[i] >= cache_max_num) {
output_cache_idx[i] = output_swap_cache_idx[i];
for (size_t i = 0; i < miss_count; ++i) {
int idx = miss_idx[i];
output_cache_idx[idx] = output_swap_cache_idx[i];
std::vector<size_t> out_shape;
std::vector<TypeId> dtypes;
for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},
} // namespace kernel
} // namespace mindspore

@ -27,7 +27,6 @@
namespace mindspore {
namespace kernel {
class MapCacheIdxCPUKernel : public CPUKernel {
MapCacheIdxCPUKernel() = default;
@ -45,6 +44,7 @@ class MapCacheIdxCPUKernel : public CPUKernel {
size_t batch_size_{1};
size_t hashmap_length_{1};
TypeId dtype_{kTypeUnknown};
CNodePtr node_ = nullptr;
@ -98,7 +98,6 @@ MS_REG_CPU_KERNEL(MapCacheIdx,
} // namespace kernel
} // namespace mindspore

@ -99,6 +99,5 @@ void SearchCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs
MS_LOG(INFO) << "avg search count: " << total_count / count_size;
MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size;
} // namespace kernel
} // namespace mindspore

@ -27,7 +27,6 @@
namespace mindspore {
namespace kernel {
template <typename T>
struct HashmapEntry {
T key;
@ -133,7 +132,6 @@ MS_REG_CPU_KERNEL(SearchCacheIdx,
} // namespace kernel
} // namespace mindspore

@ -21,20 +21,9 @@
namespace mindspore {
namespace kernel {
void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) {
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
if (indices_shape.size() < 2) {
MS_LOG(EXCEPTION) << "indices shape less than 2";
for (size_t i = 0; i < indices_shape.size(); ++i) {
batch_size_ *= indices_shape[i];
node_ = kernel_node;
for (size_t i = 0; i < update_shape.size(); ++i) {
update_size_ *= update_shape[i];
update_length_ = update_size_ / batch_size_;
input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1);
@ -64,6 +53,19 @@ bool UpdateCacheCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
template <typename T>
void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2);
batch_size_ = 1;
for (size_t i = 0; i < indices_shape.size(); ++i) {
batch_size_ *= indices_shape[i];
MS_LOG(INFO) << "UpdateCache batch_size:" << batch_size_;
update_size_ = 1;
for (size_t i = 0; i < update_shape.size(); ++i) {
update_size_ *= update_shape[i];
update_length_ = update_shape[1];
char *input_x = reinterpret_cast<char *>(inputs[0]->addr);
T *indices = reinterpret_cast<T *>(inputs[1]->addr);
char *update = reinterpret_cast<char *>(inputs[2]->addr);
@ -80,6 +82,5 @@ void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
} // namespace kernel
} // namespace mindspore

@ -46,6 +46,7 @@ class UpdateCacheCPUKernel : public CPUKernel {
TypeId input_x_dtype_{kTypeUnknown};
TypeId indices_dtype_{kTypeUnknown};
size_t input_x_dtype_size_ = 4;
CNodePtr node_ = nullptr;
@ -101,7 +102,6 @@ MS_REG_CPU_KERNEL(UpdateCache,
} // namespace kernel
} // namespace mindspore

@ -201,7 +201,12 @@ AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &prim
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
// Abstract inference for the MapCacheIdx primitive. Expects five arguments and
// returns a tuple of four tensors (cache_idx, old_emb_idx, miss_emb_idx,
// swap_emb_idx), all using the hashmap's element type.
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
// Abstract inference for the CacheSwapTable primitive. Expects three arguments
// and returns a tensor using the cache table's element type.
AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
// Abstract inference for the UpdateCache primitive. Returns a tensor using
// input_x's element type.
AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

@ -254,6 +254,99 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv
return std::make_shared<AbstractTensor>(x->element(), x->shape());
AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 5);
auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto indices_shp = indices->shape();
ShapeVector shape;
ShapeVector min_shape;
ShapeVector max_shape;
if (!indices_shp->max_shape().empty()) {
max_shape = indices_shp->max_shape();
} else {
max_shape = indices_shp->shape();
for (size_t i = 0; i < max_shape.size(); i++) {
auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape());
auto old_emb_idx =
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
auto miss_emb_idx =
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
auto swap_emb_idx =
std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
return std::make_shared<AbstractTuple>(elements);
AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto cache_table_shp = cache_table->shape();
auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto swap_cache_idx_shp = swap_cache_idx->shape();
auto cache_table_shape = cache_table_shp->shape();
auto swap_cache_idx_shape = swap_cache_idx_shp->shape();
ShapeVector shape;
auto swap_cache_idx_max_shape = swap_cache_idx_shp->max_shape();
ShapeVector max_shape;
ShapeVector min_shape;
if (!swap_cache_idx_max_shape.empty()) {
} else {
max_shape = shape;
for (size_t i = 0; i < max_shape.size(); ++i) {
AbstractTensorPtr ret =
std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
return ret;
AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();
auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
ShapeVector shape;
AbstractTensorPtr ret = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
return ret;
AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
const std::string op_name = primitive->name();

@ -56,6 +56,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
{prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
{prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
{prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
{prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}},
{prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
{prim::kPrimDiv, {InferImplDiv, true}},
{prim::kPrimRealDiv, {InferImplRealDiv, true}},
{prim::kPrimShape, {InferImplShape, false}},

@ -98,6 +98,9 @@ inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared<Primitive>(
inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared<Primitive>("UnsortedSegmentMin");
inline const PrimitivePtr kPrimConcatOffset = std::make_shared<Primitive>("ConcatOffset");
inline const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared<Primitive>("MapCacheIdx");
inline const PrimitivePtr kPrimUpdateCache = std::make_shared<Primitive>("UpdateCache");
inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared<Primitive>("CacheSwapTable");
inline const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
inline const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared<Primitive>("AccumulateNV2");

@ -15,11 +15,11 @@
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck
from .. import signature as sig
class UpdateCache(PrimitiveWithInfer):
class UpdateCache(PrimitiveWithCheck):
Update the value of input_x, similar to ScatterNdUpdate.
The difference is that UpdateCache will not update when indices < 0 or indices >= max_num.
@ -47,15 +47,12 @@ class UpdateCache(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
if len(indices_shape) < 2:
raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, "
"but got %d." % len(indices_shape))
def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
return [1]
def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type,
def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
"indices", indices_dtype, mstype.int_type,
return input_x_dtype
@ -139,7 +136,8 @@ class SearchCacheIdx(PrimitiveWithInfer):
def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type,
args, mstype.int_type,
out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
return out_dtype
@ -172,7 +170,6 @@ class CacheSwapHashmap(PrimitiveWithInfer):
outputs=['swap_cache_idx', 'old_emb_idx'])
def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape):
if len(hashmap_shape) != 2:
raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, "
"but got %d." % len(hashmap_shape))
@ -181,12 +178,13 @@ class CacheSwapHashmap(PrimitiveWithInfer):
return out_shape
def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type,
"miss_emb_idx", miss_emb_idx_dtype, mstype.int_type,
out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
return out_dtype
class CacheSwapTable(PrimitiveWithInfer):
class CacheSwapTable(PrimitiveWithCheck):
Delete a hashmap entry, insert a new key into the hashmap, and return the key and value of the deleted entry.
@ -212,21 +210,20 @@ class CacheSwapTable(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
def infer_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
if len(cache_table_shape) != 2:
raise ValueError(
"cache table shape must be 2, but got %d" % len(cache_table_shape))
if swap_cache_idx_shape + cache_table_shape[1:] != miss_value_shape:
raise ValueError(
"swap_cache_idx_shape + cache_table_shape[1:] must equal to miss_value_shape")
return miss_value_shape
def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type,
def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
"swap_cache_idx", swap_cache_idx_dtype, mstype.int_type,
return miss_value_dtype
class MapCacheIdx(PrimitiveWithInfer):
class MapCacheIdx(PrimitiveWithCheck):
MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap and UpdateCache together.
Given an indices tensor, it outputs the cache indices found by searching the hashmap.
@ -244,21 +241,34 @@ class MapCacheIdx(PrimitiveWithInfer):
def __init__(self):
"""init MapCacheIdx"""
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'],
self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape):
def __check__(self, hashmap, indices, step, emb_max_num, offset):
hashmap_shape = hashmap['shape']
if len(hashmap_shape) != 2:
raise ValueError("The dimension of 'hashmap' in SearchCacheIdx must be 2, "
"but got %d." % len(hashmap_shape))
out_shape = (indices_shape, indices_shape,
indices_shape, indices_shape)
return out_shape
out_shape = (indices['shape'], -1, -1, -1)
def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
hashmap_dtype = hashmap['dtype']
indices_dtype = indices['dtype']
args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type,
validator.check_tensor_type_same(args, mstype.int_type,
out_dtype = (hashmap_dtype, hashmap_dtype,
hashmap_dtype, hashmap_dtype)
return out_dtype
out = {'shape': out_shape,
'dtype': out_dtype,
'value': None}
if 'max_shape' in indices:
out['max_shape'] = (indices['max_shape'], indices['max_shape'],
indices['max_shape'], indices['max_shape'])
out['max_shape'] = (indices['shape'], indices['shape'],
indices['shape'], indices['shape'])
if 'min_shape' in indices:
out['min_shape'] = (indices['min_shape'], 0, 0, 0)
out['min_shape'] = (0, 0, 0, 0)
return out

@ -75,19 +75,6 @@ class CacheSwapHashmapNet(nn.Cell):
return self.ops(, miss_emb_idx, self.step)
class MapCacheIdxNet(nn.Cell):
def __init__(self, hashmap_np):
self.ops = P.MapCacheIdx()
self.hashmap = Parameter(Tensor(hashmap_np), name="hashmap")
self.emb_max = 25
self.cache_max = 10
self.step = 0
def construct(self, indices):
return self.ops(self.hashmap, indices, self.step, self.emb_max, self.cache_max)
class UpdateCacheNet(nn.Cell):
def __init__(self, x):
@ -165,45 +152,6 @@ def test_cache_swap_hashmap():
np.array(hashmap_np_after_ops, np.int32))
def test_map_cache_idx():
hashmap_np = init_hashmap(10)
indices_np = np.array([10, 2, 20, 5, 3], np.int32)
map_cache_idx = MapCacheIdxNet(hashmap_np)
indices = Tensor(indices_np)
cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = map_cache_idx(
expect_cache_idx = [5, 1, 9, 7, 3]
expect_old_emb_idx = [-1, -1, 21, 15, -1]
expect_miss_emb_idx = [-1, -1, 20, 5, -1]
expect_swap_cache_idx = [-1, -1, 9, 7, -1]
hashmap_np_after_ops = [[5, 7, 0, 1],
[10, 5, 0, 1],
[2, 1, 0, 1],
[20, 9, 0, 1],
[20, 9, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
[3, 3, 0, 1],
[21, 9, -5, 0]]
assert np.allclose(cache_idx.asnumpy(),
np.array(expect_cache_idx, np.int32))
assert np.allclose(old_emb_idx.asnumpy(),
np.array(expect_old_emb_idx, np.int32))
assert np.allclose(miss_emb_idx.asnumpy(),
np.array(expect_miss_emb_idx, np.int32))
assert np.allclose(swap_cache_idx.asnumpy(),
np.array(expect_swap_cache_idx, np.int32))
assert np.allclose(,
np.array(hashmap_np_after_ops, np.int32))
