|
|
|
@ -23,19 +23,19 @@ namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
namespace {
|
|
|
|
|
template <typename T>
|
|
|
|
|
void LookUpTableTask(const float *input_addr, const T *indices_addr, float *output_addr, size_t indices_lens,
|
|
|
|
|
size_t outer_dim_size, T offset, size_t first_dim_size) {
|
|
|
|
|
void LookUpTableTask(const float *input_addr, const T *indices_addr, float *output_addr, float *output_max_addr,
|
|
|
|
|
size_t indices_lens, size_t outer_dim_size, T offset, size_t first_dim_size) {
|
|
|
|
|
size_t lens = outer_dim_size * sizeof(float);
|
|
|
|
|
for (size_t i = 0; i < indices_lens; ++i) {
|
|
|
|
|
T index = indices_addr[i] - offset;
|
|
|
|
|
if (index >= 0 && index < SizeToInt(first_dim_size)) {
|
|
|
|
|
size_t pos = index * outer_dim_size;
|
|
|
|
|
auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens);
|
|
|
|
|
auto ret = memcpy_s(output_addr, output_max_addr - output_addr, input_addr + pos, lens);
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto ret = memset_s(output_addr, lens, 0, lens);
|
|
|
|
|
auto ret = memset_s(output_addr, output_max_addr - output_addr, 0, lens);
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "LookUpTable task memset failed.";
|
|
|
|
|
}
|
|
|
|
@ -83,8 +83,8 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens;
|
|
|
|
|
threads[i] = std::thread(LookUpTableTask<T>, input_addr, indices_addr + task_offset,
|
|
|
|
|
output_addr + task_offset * outer_dim_size_, task_proc_lens, outer_dim_size_, offset_,
|
|
|
|
|
first_dim_size_);
|
|
|
|
|
output_addr + task_offset * outer_dim_size_, output_addr + outputs[0]->size,
|
|
|
|
|
task_proc_lens, outer_dim_size_, offset_, first_dim_size_);
|
|
|
|
|
task_offset += task_proc_lens;
|
|
|
|
|
if (task_offset + task_proc_lens > indices_lens_) {
|
|
|
|
|
task_proc_lens = indices_lens_ - task_offset;
|
|
|
|
|