diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc
index 3760f1b5a1..ef12f985a6 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.cc
@@ -35,7 +35,6 @@ void AssignCPUKernel::InitKernel(const CNodePtr &kernel_node) {
     }
   }
   input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
-
   if (input_x_dtype_ == kNumberTypeFloat32 || input_x_dtype_ == kNumberTypeInt32) {
     input_x_dtype_size_ = 4;
   } else if (input_x_dtype_ == kNumberTypeFloat64 || input_x_dtype_ == kNumberTypeInt64) {
@@ -75,6 +74,5 @@ void AssignCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
   }
 }
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h
index 9c7d5d086a..b58b4f1756 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/assign_cpu_kernel.h
@@ -60,7 +60,6 @@ MS_REG_CPU_KERNEL(
   Assign,
   KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   AssignCPUKernel);
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.cc
index 2245626df7..14dd556658 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.cc
@@ -20,7 +20,6 @@
 
 namespace mindspore {
 namespace kernel {
-
 template <typename T>
 void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
   T i = (entry + 1) % length, off = 1;
@@ -107,6 +106,5 @@ void CacheSwapHashmapCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
     }
   }
 }
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.h
index fcbcf1265a..d92acf6b05 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cache_swap_hashmap_cpu_kernel.h
@@ -25,7 +25,6 @@
 
 namespace mindspore {
 namespace kernel {
-
 class CacheSwapHashmapCPUKernel : public CPUKernel {
  public:
   CacheSwapHashmapCPUKernel() = default;
@@ -82,7 +81,6 @@ MS_REG_CPU_KERNEL(CacheSwapHashmap,
                     .AddOutputAttr(kNumberTypeInt32)
                     .AddOutputAttr(kNumberTypeInt32),
                   CacheSwapHashmapCPUKernel);
-
 }  // namespace kernel
 }  // namespace mindspore
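The next file's change to Compress makes it return the number of slots it probed, so LaunchKernel can report an average delete cost. Compress itself is back-shift compaction for a linear-probing hashmap: once a slot is emptied, entries that had been displaced past it are pulled back so every probe chain still terminates at the first empty slot. A minimal standalone Python sketch of the same scheme (illustrative names only, not the kernel's API; entries are (key, value, tag) tuples with the step field omitted, and the table is assumed to always contain at least one empty slot):

    def compress(table, empty_slot):
        """Back-shift compaction after table[empty_slot] was emptied.

        Returns the number of slots visited, mirroring the new
        compress_count return value in the C++ kernel.
        """
        length = len(table)
        i = (empty_slot + 1) % length
        off = 1
        visited = 0
        while table[i] is not None:  # stop at the first empty slot
            key, value, tag = table[i]
            if tag > off:  # entry was displaced past the hole: pull it back
                table[empty_slot] = (key, value, tag - off)
                table[i] = None  # its old slot becomes the new hole
                empty_slot, off = i, 0
            i = (i + 1) % length
            off += 1
            visited += 1
        return visited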
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc
index 14c787ac57..f0a962e07f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc
@@ -22,7 +22,6 @@
 
 namespace mindspore {
 namespace kernel {
-
 template <typename T>
 struct HashmapEntry {
   T key;
@@ -60,8 +59,9 @@ T HashFunc(const T &key, const size_t &m) {
 }
 
 template <typename T>
-void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
+int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
   T i = (entry + 1) % length, off = 1;
+  int compress_count = 0;
   for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) {
     if (entry_p[i].tag > off) {
       entry_p[entry].key = entry_p[i].key;
@@ -72,21 +72,20 @@ void Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
       off = 0;
       entry = i;
     }
+    compress_count++;
   }
+  return compress_count;
 }
 
 void MapCacheIdxCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  node_ = kernel_node;
   auto hashmap_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
   if (hashmap_shape.size() != 2) {
     MS_LOG(EXCEPTION) << "Dimension of HashMap must be 2, (n, 4)";
   }
-  for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
-    batch_size_ *= emb_idx_shape[i];
-  }
-
   hashmap_length_ = hashmap_shape[0];
   dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
 }
@@ -108,100 +107,124 @@ bool MapCacheIdxCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
 template <typename T>
 void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                         const std::vector<AddressPtr> &outputs) {
+  auto emb_idx_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
+  batch_size_ = 1;
+  for (size_t i = 0; i < emb_idx_shape.size(); ++i) {
+    batch_size_ *= emb_idx_shape[i];
+  }
   HashmapEntry<T> *hashmap = reinterpret_cast<HashmapEntry<T> *>(inputs[0]->addr);
   auto input_indices = reinterpret_cast<T *>(inputs[1]->addr);
   T *step_ = reinterpret_cast<T *>(inputs[2]->addr);
   T emb_max_num = *reinterpret_cast<T *>(inputs[3]->addr);
-  T cache_max_num = *reinterpret_cast<T *>(inputs[4]->addr);
+  T offset = *reinterpret_cast<T *>(inputs[4]->addr);
   auto output_cache_idx = reinterpret_cast<T *>(outputs[0]->addr);
   auto output_old_emb_idx = reinterpret_cast<T *>(outputs[1]->addr);
   auto output_miss_emb_idx = reinterpret_cast<T *>(outputs[2]->addr);
   auto output_swap_cache_idx = reinterpret_cast<T *>(outputs[3]->addr);
-  std::vector<T> output_miss_idx(batch_size_, -1);
-
+  std::vector<size_t> miss_idx;
+  size_t miss_count = 0;
   float total_count = 0;
   int count_size = 0;
   float hit_count = 0;
-
   // search_cache_idx
   for (size_t i = 0; i < batch_size_; ++i) {
-    if (input_indices[i] == emb_max_num) {
-      output_miss_idx[i] = -1;
-      output_cache_idx[i] = cache_max_num;
-      output_miss_emb_idx[i] = -1;
+    T key = input_indices[i] - offset;
+    if (key >= emb_max_num || key < 0) {
+      output_cache_idx[i] = -1;
       continue;
    }
-    T key = input_indices[i];
     T tmp_entry = HashFunc(key, hashmap_length_);
-    int count = 1;
+    size_t count = 1;
     count_size += 1;
     while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
       tmp_entry = (tmp_entry + 1) % hashmap_length_;
+      if (count > hashmap_length_) {
+        MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!";
+        break;
+      }
       count += 1;
     }
     total_count += count;
     if (hashmap[tmp_entry].IsEmpty()) {
-      output_miss_idx[i] = i;
-      output_miss_emb_idx[i] = key;
+      miss_idx.emplace_back(i);
+      output_miss_emb_idx[miss_count] = key;
       output_cache_idx[i] = -1;
+      miss_count++;
     } else {
       hit_count += 1;
-      output_miss_idx[i] = -1;
       output_cache_idx[i] = hashmap[tmp_entry].value;
       hashmap[tmp_entry].step = step_[0];
-      output_miss_emb_idx[i] = -1;
     }
   }
-  MS_LOG(INFO) << "avg search count: " << total_count / count_size;
-  MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size;
+  MS_LOG(INFO) << "Miss count: " << miss_count;
+  MS_LOG(INFO) << "Avg search count: " << total_count / count_size;
+  MS_LOG(INFO) << "Cache hit rate: " << hit_count / count_size;
+
+  float total_insert_count = 0;
+  float total_delete_count = 0;
   // swap hash map
-  for (size_t i = 0; i < batch_size_; ++i) {
-    if (output_miss_emb_idx[i] < 0) {
-      output_swap_cache_idx[i] = -1;
-      output_old_emb_idx[i] = -1;
-    } else {
-      T emb_idx = output_miss_emb_idx[i];
-      T entry = HashFunc(emb_idx, hashmap_length_);
-      T tag_count = 1;
-      while (!hashmap[entry].IsEmpty()) {
-        entry = (entry + 1) % hashmap_length_;
-        tag_count++;
+  for (size_t i = 0; i < miss_count; ++i) {
+    T emb_idx = output_miss_emb_idx[i];
+    T entry = HashFunc(emb_idx, hashmap_length_);
+    size_t tag_count = 1;
+    while (!hashmap[entry].IsEmpty()) {
+      entry = (entry + 1) % hashmap_length_;
+      if (tag_count > hashmap_length_) {
+        MS_LOG(ERROR) << "Hashmap is full, insert new key failed!";
+        break;
       }
+      tag_count++;
+    }
 
-      hashmap[entry].key = emb_idx;
-      hashmap[entry].step = step_[0];
-      hashmap[entry].tag = tag_count;
-
-      T tmp_entry = (entry + 1) % hashmap_length_;
+    hashmap[entry].key = emb_idx;
+    hashmap[entry].step = step_[0];
+    hashmap[entry].tag = tag_count;
 
-      while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
-        tmp_entry = (tmp_entry + 1) % hashmap_length_;
+    T tmp_entry = (entry + 1) % hashmap_length_;
+    size_t delete_count = 1;
+    while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
+      tmp_entry = (tmp_entry + 1) % hashmap_length_;
+      if (delete_count > hashmap_length_) {
+        MS_LOG(ERROR) << "Hashmap is full, delete old key failed!";
+        break;
      }
-
-      output_swap_cache_idx[i] = hashmap[tmp_entry].value;
-      output_old_emb_idx[i] = hashmap[tmp_entry].key;
-      hashmap[entry].value = output_swap_cache_idx[i];
-      hashmap[tmp_entry].SetEmpty();
-      Compress(hashmap, hashmap_length_, tmp_entry);
+      delete_count++;
     }
+
+    output_swap_cache_idx[i] = hashmap[tmp_entry].value;
+    output_old_emb_idx[i] = hashmap[tmp_entry].key;
+    hashmap[entry].value = output_swap_cache_idx[i];
+    hashmap[tmp_entry].SetEmpty();
+    int compress_count = Compress(hashmap, hashmap_length_, tmp_entry);
+    total_delete_count += (compress_count + delete_count);
+    total_insert_count += tag_count;
   }
+  MS_LOG(INFO) << "Insert count: " << total_insert_count / miss_count;
+  MS_LOG(INFO) << "Delete count: " << total_delete_count / miss_count;
+
   // update step
   step_[0] += 1;
   // update cache idx
-  for (size_t i = 0; i < batch_size_; ++i) {
-    if (output_miss_idx[i] < 0 || output_miss_idx[i] >= cache_max_num) {
-      continue;
-    }
-    output_cache_idx[i] = output_swap_cache_idx[i];
+  for (size_t i = 0; i < miss_count; ++i) {
+    int idx = miss_idx[i];
+    output_cache_idx[idx] = output_swap_cache_idx[i];
   }
-}
 
+  std::vector<size_t> out_shape;
+  out_shape.emplace_back(miss_count);
+  std::vector<TypeId> dtypes;
+  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
+    dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
+  }
+  AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},
+                                      node_.get());
+}
 }  // namespace kernel
 }  // namespace mindspore
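In the rewritten LaunchKernel above, misses are now written densely: only the first miss_count entries of the miss/swap outputs are meaningful, and the actual shapes are pushed back to the graph through AnfAlgo::SetOutputInferTypeAndShape. The lookup phase can be modeled in Python with a dict standing in for the linear-probe table (a hypothetical helper for illustration, not the kernel's API):

    def map_cache_lookup(cache, indices, offset, emb_max_num):
        """Model of the search phase: out-of-range keys get -1, hits return
        their cached slot, and misses are recorded densely."""
        cache_idx, miss_idx, miss_emb_idx = [], [], []
        for i, idx in enumerate(indices):
            key = idx - offset            # new 'offset' input replaces cache_max_num
            if key < 0 or key >= emb_max_num:
                cache_idx.append(-1)      # out of range: no cache slot at all
                continue
            slot = cache.get(key)         # stands in for HashFunc plus probing
            if slot is None:
                miss_idx.append(i)        # position to patch after the swap phase
                miss_emb_idx.append(key)
                cache_idx.append(-1)
            else:
                cache_idx.append(slot)
        return cache_idx, miss_idx, miss_emb_idx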
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h
index e2c3d1f4c3..2b3e9e4728 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.h
@@ -27,7 +27,6 @@
 
 namespace mindspore {
 namespace kernel {
-
 class MapCacheIdxCPUKernel : public CPUKernel {
  public:
   MapCacheIdxCPUKernel() = default;
@@ -45,6 +44,7 @@ class MapCacheIdxCPUKernel : public CPUKernel {
   size_t batch_size_{1};
   size_t hashmap_length_{1};
   TypeId dtype_{kTypeUnknown};
+  CNodePtr node_ = nullptr;
 };
 
 MS_REG_CPU_KERNEL(MapCacheIdx,
@@ -98,7 +98,6 @@ MS_REG_CPU_KERNEL(MapCacheIdx,
                     .AddOutputAttr(kNumberTypeInt32)
                     .AddOutputAttr(kNumberTypeInt32),
                   MapCacheIdxCPUKernel);
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.cc
index 58be98b886..ec30a49f3f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.cc
@@ -99,6 +99,5 @@ void SearchCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
   MS_LOG(INFO) << "avg search count: " << total_count / count_size;
   MS_LOG(INFO) << "cache hit rate: " << hit_count / count_size;
 }
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.h
index 20bf3a4a66..5441340424 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/search_cache_idx_cpu_kernel.h
@@ -27,7 +27,6 @@
 
 namespace mindspore {
 namespace kernel {
-
 template <typename T>
 struct HashmapEntry {
   T key;
@@ -133,7 +132,6 @@ MS_REG_CPU_KERNEL(SearchCacheIdx,
                     .AddOutputAttr(kNumberTypeInt32)
                     .AddOutputAttr(kNumberTypeInt32),
                   SearchCacheIdxCPUKernel);
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc
index 0a127d1cf5..14f3b4762b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.cc
@@ -21,20 +21,9 @@
 namespace mindspore {
 namespace kernel {
 void UpdateCacheCPUKernel::InitKernel(const CNodePtr &kernel_node) {
-  auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
-  auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
-  if (indices_shape.size() < 2) {
-    MS_LOG(EXCEPTION) << "indices shape less than 2";
-  }
-
-  for (size_t i = 0; i < indices_shape.size(); ++i) {
-    batch_size_ *= indices_shape[i];
-  }
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  node_ = kernel_node;
 
-  for (size_t i = 0; i < update_shape.size(); ++i) {
-    update_size_ *= update_shape[i];
-  }
-  update_length_ = update_size_ / batch_size_;
   input_x_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0);
   indices_dtype_ = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 1);
@@ -64,6 +53,19 @@ bool UpdateCacheCPUKernel::Launch(const std::vector<AddressPtr> &inputs,
 template <typename T>
 void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                         const std::vector<AddressPtr> &outputs) {
+  auto indices_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 1);
+  auto update_shape = AnfAlgo::GetPrevNodeOutputInferShape(node_, 2);
+
+  batch_size_ = 1;
+  for (size_t i = 0; i < indices_shape.size(); ++i) {
+    batch_size_ *= indices_shape[i];
+  }
+  MS_LOG(INFO) << "UpdateCache batch_size:" << batch_size_;
+  update_size_ = 1;
+  for (size_t i = 0; i < update_shape.size(); ++i) {
+    update_size_ *= update_shape[i];
+  }
+  update_length_ = update_shape[1];
   char *input_x = reinterpret_cast<char *>(inputs[0]->addr);
   T *indices = reinterpret_cast<T *>(inputs[1]->addr);
   char *update = reinterpret_cast<char *>(inputs[2]->addr);
@@ -80,6 +82,5 @@ void UpdateCacheCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
     }
   }
 }
-
 }  // namespace kernel
 }  // namespace mindspore
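UpdateCache now reads its shapes from the saved node at launch time (so dynamically shaped indices coming from MapCacheIdx still work) and takes the row width from update_shape[1]. Its semantics, modeled with NumPy (a sketch under the assumption that update is 2-D with one row per index; this is not the kernel code):

    import numpy as np

    def update_cache(input_x, indices, update, max_num):
        """Copy update[i] into row indices[i] of input_x, skipping
        indices outside [0, max_num) -- the 'no update' cases."""
        update_length = update.shape[1]            # matches update_shape[1]
        rows = input_x.reshape(-1, update_length)  # view onto input_x
        for i, idx in enumerate(np.ravel(indices)):
            if 0 <= idx < max_num:
                rows[idx] = update[i]
        return input_x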
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h
index 67309f5b4b..553535dfe8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/update_cache_cpu_kernel.h
@@ -46,6 +46,7 @@ class UpdateCacheCPUKernel : public CPUKernel {
   TypeId input_x_dtype_{kTypeUnknown};
   TypeId indices_dtype_{kTypeUnknown};
   size_t input_x_dtype_size_ = 4;
+  CNodePtr node_ = nullptr;
 };
 
 MS_REG_CPU_KERNEL(UpdateCache,
@@ -101,7 +102,6 @@ MS_REG_CPU_KERNEL(UpdateCache,
                     .AddInputAttr(kNumberTypeInt64)
                     .AddOutputAttr(kNumberTypeInt64),
                   UpdateCacheCPUKernel);
-
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 92e2d36c84..7c9e5550d4 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -201,7 +201,12 @@ AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &prim
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplRealDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list);
-
+AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplGatherV2(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
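The three infer functions declared above (and implemented in prim_arrays.cc below) give the cache ops compile-time dynamic shapes: a dynamic output keeps a concrete rank but uses Shape::SHP_ANY extents plus min/max bounds derived from the indices shape, so memory planning can size worst-case buffers. Roughly, in Python terms (a hypothetical helper for illustration only):

    def dyn_shape_like(indices_shape, indices_max_shape=None):
        """(shape, min_shape, max_shape) triple for a miss-sized output:
        -1 plays the role of Shape::SHP_ANY, the minimum is 1 per dim,
        and the maximum falls back to the static indices shape."""
        max_shape = list(indices_max_shape or indices_shape)
        return [-1] * len(max_shape), [1] * len(max_shape), max_shape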
diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc
index 8ddf36ac75..ad8fcdf0d6 100644
--- a/mindspore/core/abstract/prim_arrays.cc
+++ b/mindspore/core/abstract/prim_arrays.cc
@@ -273,6 +273,99 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv
   return std::make_shared<AbstractTensor>(x->element(), x->shape());
 }
 
+AbstractBasePtr InferImplMapCacheIdx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 5);
+  auto hash_map = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(hash_map);
+  MS_EXCEPTION_IF_NULL(hash_map->shape());
+
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto indices_shp = indices->shape();
+  MS_EXCEPTION_IF_NULL(indices);
+  MS_EXCEPTION_IF_NULL(indices_shp);
+
+  ShapeVector shape;
+  ShapeVector min_shape;
+  ShapeVector max_shape;
+  if (!indices_shp->max_shape().empty()) {
+    max_shape = indices_shp->max_shape();
+  } else {
+    max_shape = indices_shp->shape();
+  }
+  for (size_t i = 0; i < max_shape.size(); i++) {
+    shape.emplace_back(Shape::SHP_ANY);
+    min_shape.emplace_back(1);
+  }
+
+  auto cache_idx = std::make_shared<AbstractTensor>(hash_map->element(), indices->shape());
+  auto old_emb_idx =
+    std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  auto miss_emb_idx =
+    std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  auto swap_emb_idx =
+    std::make_shared<AbstractTensor>(hash_map->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+
+  AbstractBasePtrList elements = {cache_idx, old_emb_idx, miss_emb_idx, swap_emb_idx};
+  return std::make_shared<AbstractTuple>(elements);
+}
+
+AbstractBasePtr InferImplCacheSwapTable(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                        const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 3);
+  auto cache_table = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto cache_table_shp = cache_table->shape();
+  MS_EXCEPTION_IF_NULL(cache_table);
+  MS_EXCEPTION_IF_NULL(cache_table_shp);
+
+  auto swap_cache_idx = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  auto swap_cache_idx_shp = swap_cache_idx->shape();
+  MS_EXCEPTION_IF_NULL(swap_cache_idx);
+  MS_EXCEPTION_IF_NULL(swap_cache_idx_shp);
+
+  auto cache_table_shape = cache_table_shp->shape();
+  auto swap_cache_idx_shape = swap_cache_idx_shp->shape();
+  ShapeVector shape;
+  shape.emplace_back(swap_cache_idx_shape[0]);
+  shape.emplace_back(cache_table_shape[1]);
+  auto swap_cache_idx_max_shape = swap_cache_idx_shp->max_shape();
+  ShapeVector max_shape;
+  ShapeVector min_shape;
+  if (!swap_cache_idx_max_shape.empty()) {
+    max_shape.emplace_back(swap_cache_idx_max_shape[0]);
+    max_shape.emplace_back(cache_table_shape[1]);
+  } else {
+    max_shape = shape;
+  }
+  for (size_t i = 0; i < max_shape.size(); ++i) {
+    min_shape.emplace_back(1);
+  }
+
+  AbstractTensorPtr ret =
+    std::make_shared<AbstractTensor>(cache_table->element(), std::make_shared<Shape>(shape, min_shape, max_shape));
+  return ret;
+}
+
+AbstractBasePtr InferImplUpdateCache(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  auto input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(input_x);
+  MS_EXCEPTION_IF_NULL(input_x->shape());
+
+  auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(indices);
+  MS_EXCEPTION_IF_NULL(indices->shape());
+
+  ShapeVector shape;
+  shape.emplace_back(1);
+
+  AbstractTensorPtr ret = std::make_shared<AbstractTensor>(input_x->element(), std::make_shared<Shape>(shape));
+  return ret;
+}
+
 AbstractBasePtr InferImplDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index cdf97fec99..9ffcbca521 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -57,6 +57,9 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimUnsortedSegmentSum, {InferImplUnsortedSegmentSum, true}},
     {prim::kPrimScatterAdd, {InferImplScatterAdd, true}},
     {prim::kPrimScatterUpdate, {InferImplScatterUpdate, true}},
+    {prim::kPrimMapCacheIdx, {InferImplMapCacheIdx, true}},
+    {prim::kPrimCacheSwapTable, {InferImplCacheSwapTable, true}},
+    {prim::kPrimUpdateCache, {InferImplUpdateCache, true}},
     {prim::kPrimDiv, {InferImplDiv, true}},
     {prim::kPrimRealDiv, {InferImplRealDiv, true}},
     {prim::kPrimShape, {InferImplShape, false}},
std::make_shared("UnsortedSegmentMin"); inline const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); inline const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); +inline const PrimitivePtr kPrimMapCacheIdx = std::make_shared("MapCacheIdx"); +inline const PrimitivePtr kPrimUpdateCache = std::make_shared("UpdateCache"); +inline const PrimitivePtr kPrimCacheSwapTable = std::make_shared("CacheSwapTable"); inline const PrimitivePtr kPrimTile = std::make_shared("Tile"); inline const PrimitivePtr kPrimAddN = std::make_shared("AddN"); inline const PrimitivePtr kPrimAccumulateNV2 = std::make_shared("AccumulateNV2"); diff --git a/mindspore/ops/operations/_cache_ops.py b/mindspore/ops/operations/_cache_ops.py index fa9e99c53e..5e36031fe0 100644 --- a/mindspore/ops/operations/_cache_ops.py +++ b/mindspore/ops/operations/_cache_ops.py @@ -15,11 +15,11 @@ """cache_ops""" from ..._checkparam import Validator as validator from ...common import dtype as mstype -from ..primitive import PrimitiveWithInfer, prim_attr_register +from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck from .. import signature as sig -class UpdateCache(PrimitiveWithInfer): +class UpdateCache(PrimitiveWithCheck): """ Update the value fo input_x, similar to ScatterNdUpdate. The diffirent is that UpdateCache will not update when indices < 0 or indices >= max_num. @@ -47,15 +47,12 @@ class UpdateCache(PrimitiveWithInfer): self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'], outputs=['out']) - def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): - - if len(indices_shape) < 2: - raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, " - "but got %d." % len(indices_shape)) + def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape): return [1] - def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): - validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name) + def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype): + validator.check_tensor_dtype_valid( + "indices", indices_dtype, mstype.int_type, self.name) return input_x_dtype @@ -139,7 +136,8 @@ class SearchCacheIdx(PrimitiveWithInfer): def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype): args = {"hashmap": hashmap_dtype, "indices": indices_dtype} - validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name) + validator.check_tensors_dtypes_same_and_valid( + args, mstype.int_type, self.name) out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype) return out_dtype @@ -172,7 +170,6 @@ class CacheSwapHashmap(PrimitiveWithInfer): outputs=['swap_cache_idx', 'old_emb_idx']) def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape): - if len(hashmap_shape) != 2: raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, " "but got %d." 
diff --git a/mindspore/ops/operations/_cache_ops.py b/mindspore/ops/operations/_cache_ops.py
index fa9e99c53e..5e36031fe0 100644
--- a/mindspore/ops/operations/_cache_ops.py
+++ b/mindspore/ops/operations/_cache_ops.py
@@ -15,11 +15,11 @@
 """cache_ops"""
 from ..._checkparam import Validator as validator
 from ...common import dtype as mstype
-from ..primitive import PrimitiveWithInfer, prim_attr_register
+from ..primitive import PrimitiveWithInfer, prim_attr_register, PrimitiveWithCheck
 from .. import signature as sig
 
 
-class UpdateCache(PrimitiveWithInfer):
+class UpdateCache(PrimitiveWithCheck):
     """
     Update the value of input_x, similar to ScatterNdUpdate.
     The difference is that UpdateCache will not update when indices < 0 or indices >= max_num.
@@ -47,15 +47,12 @@ class UpdateCache(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['input_x', 'indices', 'update', 'max_num'],
                                 outputs=['out'])
 
-    def infer_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
-
-        if len(indices_shape) < 2:
-            raise ValueError("The dimension of 'indices' in UpdateCache must >= 2, "
-                             "but got %d." % len(indices_shape))
+    def check_shape(self, input_x_shape, indices_shape, update_shape, max_num_shape):
         return [1]
 
-    def infer_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
-        validator.check_tensor_dtype_valid("indices", indices_dtype, mstype.int_type, self.name)
+    def check_dtype(self, input_x_dtype, indices_dtype, update_dtype, max_num_dtype):
+        validator.check_tensor_dtype_valid(
+            "indices", indices_dtype, mstype.int_type, self.name)
         return input_x_dtype
@@ -139,7 +136,8 @@ class SearchCacheIdx(PrimitiveWithInfer):
     def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
         args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
+        validator.check_tensors_dtypes_same_and_valid(
+            args, mstype.int_type, self.name)
         out_dtype = (hashmap_dtype, hashmap_dtype, hashmap_dtype)
         return out_dtype
@@ -172,7 +170,6 @@ class CacheSwapHashmap(PrimitiveWithInfer):
                                 outputs=['swap_cache_idx', 'old_emb_idx'])
 
     def infer_shape(self, hashmap_shape, miss_emb_idx_shape, step_shape):
-
         if len(hashmap_shape) != 2:
             raise ValueError("The dimension of 'hashmap' in CacheSwapHashmap must be 2, "
                              "but got %d." % len(hashmap_shape))
@@ -181,12 +178,13 @@ class CacheSwapHashmap(PrimitiveWithInfer):
         return out_shape
 
     def infer_dtype(self, hashmap_dtype, miss_emb_idx_dtype, step_dtype):
-        validator.check_tensor_dtype_valid("miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
+        validator.check_tensor_dtype_valid(
+            "miss_emb_idx", miss_emb_idx_dtype, mstype.int_type, self.name)
         out_dtype = (miss_emb_idx_dtype, miss_emb_idx_dtype)
         return out_dtype
 
 
-class CacheSwapTable(PrimitiveWithInfer):
+class CacheSwapTable(PrimitiveWithCheck):
     """
     Delete a hashmap entry, and insert a new key into the hashmap; return
     the key and value of the deleted entry.
@@ -212,21 +210,20 @@ class CacheSwapTable(PrimitiveWithCheck):
         self.init_prim_io_names(inputs=['cache_table', 'swap_cache_idx', 'miss_value'],
                                 outputs=['old_value'])
 
-    def infer_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
+    def check_shape(self, cache_table_shape, swap_cache_idx_shape, miss_value_shape):
         if len(cache_table_shape) != 2:
             raise ValueError(
                 "cache table shape must be 2, but got %d" % len(cache_table_shape))
-        if swap_cache_idx_shape + cache_table_shape[1:] != miss_value_shape:
-            raise ValueError(
-                "swap_cache_idx_shape + cache_table_shape[1:] must equal to miss_value_shape")
+
         return miss_value_shape
 
-    def infer_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
-        validator.check_tensor_dtype_valid("swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
+    def check_dtype(self, cache_table_dtype, swap_cache_idx_dtype, miss_value_dtype):
+        validator.check_tensor_dtype_valid(
+            "swap_cache_idx", swap_cache_idx_dtype, mstype.int_type, self.name)
         return miss_value_dtype
 
 
-class MapCacheIdx(PrimitiveWithInfer):
+class MapCacheIdx(PrimitiveWithCheck):
     """
     MapCacheIdx merges SearchCacheIdx, CacheSwapHashmap and UpdateCache together.
     Given an indices tensor as input, it outputs the cache indices found by searching the hashmap.
@@ -244,21 +241,34 @@ class MapCacheIdx(PrimitiveWithCheck):
     def __init__(self):
         """init MapCacheIdx"""
-        self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'cache_max_num'],
+        self.init_prim_io_names(inputs=['hashmap', 'indices', 'step', 'emb_max_num', 'offset'],
                                 outputs=['cache_idx', 'old_emb_idx', 'miss_emb_idx', 'swap_cache_idx'])
 
-    def infer_shape(self, hashmap_shape, indices_shape, step_shape, emb_max_num_shape, cache_max_num_shape):
-
+    def __check__(self, hashmap, indices, step, emb_max_num, offset):
+        hashmap_shape = hashmap['shape']
         if len(hashmap_shape) != 2:
             raise ValueError("The dimension of 'hashmap' in MapCacheIdx must be 2, "
                              "but got %d." % len(hashmap_shape))
-        out_shape = (indices_shape, indices_shape,
-                     indices_shape, indices_shape)
-        return out_shape
+        out_shape = (indices['shape'], -1, -1, -1)
 
-    def infer_dtype(self, hashmap_dtype, indices_dtype, step_dtype, emb_max_num_dtype, cache_max_num_dtype):
+        hashmap_dtype = hashmap['dtype']
+        indices_dtype = indices['dtype']
         args = {"hashmap": hashmap_dtype, "indices": indices_dtype}
-        validator.check_tensors_dtypes_same_and_valid(args, mstype.int_type, self.name)
+        validator.check_tensor_type_same(args, mstype.int_type, self.name)
         out_dtype = (hashmap_dtype, hashmap_dtype,
                      hashmap_dtype, hashmap_dtype)
-        return out_dtype
+
+        out = {'shape': out_shape,
+               'dtype': out_dtype,
+               'value': None}
+        if 'max_shape' in indices:
+            out['max_shape'] = (indices['max_shape'], indices['max_shape'],
+                                indices['max_shape'], indices['max_shape'])
+        else:
+            out['max_shape'] = (indices['shape'], indices['shape'],
+                                indices['shape'], indices['shape'])
+        if 'min_shape' in indices:
+            out['min_shape'] = (indices['min_shape'], 0, 0, 0)
+        else:
+            out['min_shape'] = (0, 0, 0, 0)
+        return out
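With 'cache_max_num' renamed to 'offset' and three of MapCacheIdx's four outputs now sized by the number of misses, the fixed-shape expectations in the test below no longer hold, which is presumably why the MapCacheIdxNet case is deleted rather than updated. For orientation only, a rough sketch of how a call would look against the new signature (hypothetical values; the imports follow the test file's style, assuming its P import still exposes MapCacheIdx; this is not a replacement test):

    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    class MapCacheIdxNet(nn.Cell):
        def __init__(self, hashmap_np):
            super().__init__()
            self.ops = P.MapCacheIdx()
            self.hashmap = Parameter(Tensor(hashmap_np), name="hashmap")
            self.emb_max = 25
            self.offset = 0  # replaces the old cache_max input
            self.step = 0

        def construct(self, indices):
            # outputs: cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx;
            # the last three now have miss_count rows
            return self.ops(self.hashmap, indices, self.step,
                            self.emb_max, self.offset)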
diff --git a/tests/st/ops/cpu/test_cache_ops.py b/tests/st/ops/cpu/test_cache_ops.py
index ef670c9ac4..3b30a534be 100644
--- a/tests/st/ops/cpu/test_cache_ops.py
+++ b/tests/st/ops/cpu/test_cache_ops.py
@@ -75,19 +75,6 @@ class CacheSwapHashmapNet(nn.Cell):
         return self.ops(self.net.hashmap, miss_emb_idx, self.step)
 
 
-class MapCacheIdxNet(nn.Cell):
-    def __init__(self, hashmap_np):
-        super().__init__()
-        self.ops = P.MapCacheIdx()
-        self.hashmap = Parameter(Tensor(hashmap_np), name="hashmap")
-        self.emb_max = 25
-        self.cache_max = 10
-        self.step = 0
-
-    def construct(self, indices):
-        return self.ops(self.hashmap, indices, self.step, self.emb_max, self.cache_max)
-
-
 class UpdateCacheNet(nn.Cell):
     def __init__(self, x):
         super().__init__()
@@ -165,45 +152,6 @@ def test_cache_swap_hashmap():
                        np.array(hashmap_np_after_ops, np.int32))
 
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_map_cache_idx():
-    hashmap_np = init_hashmap(10)
-    indices_np = np.array([10, 2, 20, 5, 3], np.int32)
-    map_cache_idx = MapCacheIdxNet(hashmap_np)
-    indices = Tensor(indices_np)
-    cache_idx, old_emb_idx, miss_emb_idx, swap_cache_idx = map_cache_idx(
-        indices)
-
-    expect_cache_idx = [5, 1, 9, 7, 3]
-    expect_old_emb_idx = [-1, -1, 21, 15, -1]
-    expect_miss_emb_idx = [-1, -1, 20, 5, -1]
-    expect_swap_cache_idx = [-1, -1, 9, 7, -1]
-
-    hashmap_np_after_ops = [[5, 7, 0, 1],
-                            [10, 5, 0, 1],
-                            [2, 1, 0, 1],
-                            [20, 9, 0, 1],
-                            [20, 9, 0, 0],
-                            [0, 0, 0, 0],
-                            [0, 0, 0, 0],
-                            [0, 0, 0, 0],
-                            [3, 3, 0, 1],
-                            [21, 9, -5, 0]]
-
-    assert np.allclose(cache_idx.asnumpy(),
-                       np.array(expect_cache_idx, np.int32))
-    assert np.allclose(old_emb_idx.asnumpy(),
-                       np.array(expect_old_emb_idx, np.int32))
-    assert np.allclose(miss_emb_idx.asnumpy(),
-                       np.array(expect_miss_emb_idx, np.int32))
-    assert np.allclose(swap_cache_idx.asnumpy(),
-                       np.array(expect_swap_cache_idx, np.int32))
-    assert np.allclose(map_cache_idx.hashmap.data.asnumpy(),
-                       np.array(hashmap_np_after_ops, np.int32))
-
-
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard