|
|
|
@ -65,7 +65,7 @@ int ReduceFp16CPUKernel::CallReduceUnit(int task_id) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static int ReduceImpl(void *cdata, int task_id) {
|
|
|
|
|
static int ReduceFp16Impl(void *cdata, int task_id) {
|
|
|
|
|
auto reduce = reinterpret_cast<ReduceFp16CPUKernel *>(cdata);
|
|
|
|
|
auto error_code = reduce->CallReduceUnit(task_id);
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
@ -102,7 +102,7 @@ int ReduceFp16CPUKernel::Run() {
|
|
|
|
|
outer_size_ = outer_sizes_[i];
|
|
|
|
|
inner_size_ = inner_sizes_[i];
|
|
|
|
|
axis_size_ = axis_sizes_[i];
|
|
|
|
|
auto error_code = ParallelLaunch(this->context_->thread_pool_, ReduceImpl, this, context_->thread_num_);
|
|
|
|
|
auto error_code = ParallelLaunch(this->context_->thread_pool_, ReduceFp16Impl, this, context_->thread_num_);
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
|
FreeTmpBuffer();
|
|
|
|
|
MS_LOG(ERROR) << "Reduce run error, error_code[" << error_code << "]";
|
|
|
|
|