|
|
|
@ -16,7 +16,6 @@
|
|
|
|
|
|
|
|
|
|
#include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h"
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <cmath>
|
|
|
|
|
#include "runtime/device/cpu/cpu_device_address.h"
|
|
|
|
|
#include "common/thread_pool.h"
|
|
|
|
|
|
|
|
|
@ -79,37 +78,17 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in
|
|
|
|
|
MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
|
|
|
|
if (unit_num_ < thread_num) {
|
|
|
|
|
thread_num = unit_num_;
|
|
|
|
|
}
|
|
|
|
|
std::vector<common::Task> tasks;
|
|
|
|
|
tasks.reserve(thread_num);
|
|
|
|
|
auto task = [&](size_t start, size_t end) {
|
|
|
|
|
for (size_t c = 0; c < ceil(static_cast<double>(unit_num_) / thread_num); ++c) {
|
|
|
|
|
if (c * thread_num + start >= unit_num_) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
size_t i = c * thread_num + start;
|
|
|
|
|
size_t j = i / input_dim1_;
|
|
|
|
|
size_t k = i % input_dim1_;
|
|
|
|
|
for (size_t i = 0; i < unit_num_; ++i) {
|
|
|
|
|
size_t j = i / input_dim1_;
|
|
|
|
|
size_t k = i % input_dim1_;
|
|
|
|
|
|
|
|
|
|
T index = indices_addr[j];
|
|
|
|
|
if (index < 0 || index >= SizeToInt(output_dim0_)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
size_t output_index = index * output_dim1_ + k;
|
|
|
|
|
output_addr[output_index] += input_addr[i];
|
|
|
|
|
T index = indices_addr[j];
|
|
|
|
|
if (index < 0 || index >= SizeToInt(output_dim0_)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
for (size_t t = 0; t < thread_num; ++t) {
|
|
|
|
|
auto block = [&, t]() {
|
|
|
|
|
task(t, t + 1);
|
|
|
|
|
return common::SUCCESS;
|
|
|
|
|
};
|
|
|
|
|
tasks.emplace_back(block);
|
|
|
|
|
size_t output_index = index * output_dim1_ + k;
|
|
|
|
|
output_addr[output_index] += input_addr[i];
|
|
|
|
|
}
|
|
|
|
|
common::ThreadPool::GetInstance().SyncRun(tasks);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
} // namespace kernel
|
|
|
|
|