@@ -15,11 +15,11 @@
 """cache_ops"""
 from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
-from ..primitive import PrimitiveWithInfer, prim_attr_register
+from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck
 from .. import signature as sig


-class UpdateCache(PrimitiveWithInfer):
+class UpdateCache(PrimitiveWithCheck):
     """
     Update the value of input_x, similar to ScatterNdUpdate.
     The difference is that UpdateCache will not update when indices < 0 or indices >= max_num.
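# A minimal NumPy sketch of the semantics described in the docstring above
# (illustration only, not the MindSpore kernel; the (N, D) row layout is an
# assumption): like a scatter update, except that rows whose index falls
# outside [0, max_num) are silently skipped rather than raising an error.
import numpy as np

def update_cache_ref(input_x, indices, update, max_num):
    # input_x: (N, D) cache table; indices: (K,); update: (K, D)
    for i, idx in enumerate(indices):
        if 0 <= idx < max_num:        # out-of-range indices are ignored
            input_x[idx] = update[i]
    return input_x

# Example: index 5 is >= max_num, so only row 1 is updated.
cache = np.zeros((4, 2))
update_cache_ref(cache, np.array([1, 5]), np.ones((2, 2)), max_num=4)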
@@ -47,15 +47,12 @@ class UpdateCache(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
                                 outputs=['out'])

-    def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
-        if len(indices_shape) < 2:
-            raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, "
-                             "but got %d." % len(indices_shape))
+    def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
         return [1]

-    def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
-        validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
+    def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
+        validator.check_tensor_dtype_valid(
+            "indices", indices_dtype, mstype.int_type, self.name)
         return input_x_dtype
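# Why the patch renames infer_* to check_*: PrimitiveWithInfer must produce a
# concrete output shape and dtype at compile time, while PrimitiveWithCheck
# only validates its inputs and leaves the output shape to be resolved at
# runtime, which is what permits dynamic shapes. A toy sketch of the two
# contracts (our names, not the MindSpore base classes):
class InferStyle:
    def infer_shape(self, x_shape):
        # must return the fully-known output shape here
        return [x_shape[0], 1]

class CheckStyle:
    def check_shape(self, x_shape):
        # validate only; the output shape is derived later, so
        # dynamic (unknown) dimensions are allowed
        if len(x_shape) != 2:
            raise ValueError("expected a 2-D input")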
@@ -139,7 +136,8 @@ class SearchCacheIdx(PrimitiveWithInfer):
     def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
         args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(
+            args, mstype.int_type, self.name)
         out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
         return out_dtype
@@ -172,7 +170,6 @@ class CacheSwapHashmap(PrimitiveWithInfer):
                                 outputs=['swap_cache_idx', 'old_emb_idx'])

     def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape):
         if len(hashmap_shape) != 2:
             raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, "
                              "but got %d." % len(hashmap_shape))
@@ -181,12 +178,13 @@ class CacheSwapHashmap(PrimitiveWithInfer):
         return out_shape

     def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
-        validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid(
+            "miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
         out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
         return out_dtype


-class CacheSwapTable(PrimitiveWithInfer):
+class CacheSwapTable(PrimitiveWithCheck):
     """
     Delete a hashmap entry, and insert a new key into the hashmap; return the key and value of the deleted entry.
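# A NumPy sketch of CacheSwapTable's data movement (an illustration of the
# docstring above, not the kernel; names and the in-place update are
# assumptions): rows being evicted are read out as old_value, then
# overwritten in place with the missing values.
import numpy as np

def cache_swap_table_ref(cache_table, swap_cache_idx, miss_value):
    old_value = cache_table[swap_cache_idx].copy()  # values swapped out
    cache_table[swap_cache_idx] = miss_value        # values swapped in
    return old_value

# Example: rows 0 and 2 are returned, then replaced by the miss values.
table = np.arange(8.0).reshape(4, 2)
old = cache_swap_table_ref(table, np.array([0, 2]), np.full((2, 2), -1.0))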
@@ -212,21 +210,20 @@ class CacheSwapTable(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
                                 outputs=['old_value'])

-    def infer_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
+    def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
         if len(cache_table_shape) != 2:
             raise ValueError(
                 "cache table shape must be 2, but got %d" % len(cache_table_shape))
-        if swap_cache_idx_shape + cache_table_shape[1:] != miss_value_shape:
-            raise ValueError(
-                "swap_cache_idx_shape + cache_table_shape[1:] must equal to miss_value_shape")

         return miss_value_shape

-    def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
-        validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
+    def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
+        validator.check_tensor_dtype_valid(
+            "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
         return miss_value_dtype


-class MapCacheIdx(PrimitiveWithInfer):
+class MapCacheIdx(PrimitiveWithCheck):
     """
     MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap and UpdateCache together.
     When given an indices tensor, it outputs the cache indices found by searching the hashmap.
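# A high-level sketch of the fused lookup the docstring describes (assumed
# control flow, not the kernel; hashmap_lookup and evict_slot are hypothetical
# helpers): a hit reuses its cache slot, a miss evicts a slot and records
# which embedding rows must be swapped out and in.
def map_cache_idx_ref(hashmap_lookup, evict_slot, indices):
    cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = [], [], [], []
    for ind in indices:
        slot = hashmap_lookup(ind)           # SearchCacheIdx step
        if slot is None:                     # cache miss
            slot, old = evict_slot(ind)      # CacheSwapHashmap step
            old_emb_idx.append(old)          # embedding row to write back
            miss_emb_idx.append(ind)         # embedding row to fetch
            swap_cache_idx.append(slot)      # cache slot being replaced
        cache_idx.append(slot)
    return cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx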
@@ -244,21 +241,34 @@ class MapCacheIdx(PrimitiveWithInfer):
     def __init__(self):
         """init MapCacheIdx"""
-        self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'],
+        self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
                                 outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])

-    def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape):
+    def __check__(self, hashmap, indices, step, emb_max_num, offset):
+        hashmap_shape = hashmap['shape']
         if len(hashmap_shape) != 2:
             raise ValueError("The dimension of 'hashmap' in MapCacheIdx must be 2, "
                              "but got %d." % len(hashmap_shape))
-        out_shape = (indices_shape, indices_shape,
-                     indices_shape, indices_shape)
-        return out_shape
+        out_shape = (indices['shape'], -1, -1, -1)

-    def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
+        hashmap_dtype = hashmap['dtype']
+        indices_dtype = indices['dtype']
         args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
+        validator.check_tensor_type_same(args, mstype.int_type, self.name)
         out_dtype = (hashmap_dtype, hashmap_dtype,
                      hashmap_dtype, hashmap_dtype)
-        return out_dtype
+
+        out = {'shape': out_shape,
+               'dtype': out_dtype,
+               'value': None}
+        if 'max_shape' in indices:
+            out['max_shape'] = (indices['max_shape'], indices['max_shape'],
+                                indices['max_shape'], indices['max_shape'])
+        else:
+            out['max_shape'] = (indices['shape'], indices['shape'],
+                                indices['shape'], indices['shape'])
+        if 'min_shape' in indices:
+            out['min_shape'] = (indices['min_shape'], 0, 0, 0)
+        else:
+            out['min_shape'] = (0, 0, 0, 0)
+        return out
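# How the dict returned by __check__ above is consumed (a sketch with
# illustrative values, based only on the code in this hunk): -1 marks a
# dimension that is only known at runtime, while 'max_shape'/'min_shape'
# bound it so the framework can plan memory ahead of execution.
example_out = {'shape': ((32,), -1, -1, -1),   # last three outputs are dynamic
               'dtype': ('int32',) * 4,        # stands in for mstype objects
               'value': None,
               'max_shape': ((32,), (32,), (32,), (32,)),
               'min_shape': ((32,), 0, 0, 0)}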