|
|
|
@ -82,8 +82,8 @@ int GatherInt8CPUKernel::DoGather(int task_id) {
|
|
|
|
int count = MSMIN(stride, outer_size - stride * task_id);
|
|
|
|
int count = MSMIN(stride, outer_size - stride * task_id);
|
|
|
|
auto thread_stride = stride * task_id;
|
|
|
|
auto thread_stride = stride * task_id;
|
|
|
|
|
|
|
|
|
|
|
|
input_ptr += thread_stride * limit;
|
|
|
|
input_ptr += thread_stride * inner_size * limit;
|
|
|
|
output_ptr += thread_stride * indices_element_size;
|
|
|
|
output_ptr += thread_stride * inner_size * indices_element_size;
|
|
|
|
return GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_ptr, indices_element_size, param_);
|
|
|
|
return GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_ptr, indices_element_size, param_);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|