diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc index 578ab4b963..25f8504614 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.cc @@ -23,19 +23,20 @@ namespace mindspore { namespace kernel { namespace { template -void LookUpTableTask(const float *input_addr, const T *indices_addr, const float *output_max_addr, float *output_addr, - size_t indices_lens, size_t outer_dim_size, T offset, size_t first_dim_size) { - size_t lens = outer_dim_size * sizeof(float); +void LookUpTableTask(const float *input_addr, const T *indices_addr, float *output_addr, size_t indices_lens, + size_t outer_dim_size, T offset, size_t first_dim_size) { + auto type_size = sizeof(float); + size_t lens = outer_dim_size * type_size; for (size_t i = 0; i < indices_lens; ++i) { T index = indices_addr[i] - offset; if (index >= 0 && index < SizeToInt(first_dim_size)) { size_t pos = index * outer_dim_size; - auto ret = memcpy_s(output_addr, output_max_addr - output_addr, input_addr + pos, lens); + auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); if (ret != EOK) { MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; } } else { - auto ret = memset_s(output_addr, output_max_addr - output_addr, 0, lens); + auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); if (ret != EOK) { MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; } @@ -82,7 +83,7 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector, input_addr, indices_addr + task_offset, output_addr + outputs[0]->size, + threads[i] = std::thread(LookUpTableTask, input_addr, indices_addr + task_offset, output_addr + task_offset * outer_dim_size_, task_proc_lens, outer_dim_size_, offset_, first_dim_size_); task_offset += task_proc_lens;