|
|
|
@ -31,6 +31,7 @@ using mindspore::lite::RET_OK;
|
|
|
|
|
using mindspore::schema::PrimitiveType_Mean;
|
|
|
|
|
using mindspore::schema::PrimitiveType_Reduce;
|
|
|
|
|
using mindspore::schema::ReduceMode;
|
|
|
|
|
using mindspore::schema::ReduceMode_ReduceAll;
|
|
|
|
|
using mindspore::schema::ReduceMode_ReduceASum;
|
|
|
|
|
using mindspore::schema::ReduceMode_ReduceMax;
|
|
|
|
|
using mindspore::schema::ReduceMode_ReduceMean;
|
|
|
|
@ -78,6 +79,10 @@ int ReduceCPUKernel::Init() {
|
|
|
|
|
reducer_ = ReduceSum;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
case static_cast<int>(ReduceMode_ReduceAll): {
|
|
|
|
|
bool_reducer_ = ReduceAll;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
@ -96,6 +101,9 @@ int ReduceCPUKernel::CallReduceUnit(int task_id) {
|
|
|
|
|
if (data_type_ == kDataTypeFloat) {
|
|
|
|
|
ret = reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_),
|
|
|
|
|
static_cast<float *>(dst_data_), task_id, context_->thread_num_);
|
|
|
|
|
} else if (data_type_ == KDataTypeBool) {
|
|
|
|
|
ret = bool_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const bool *>(src_data_),
|
|
|
|
|
static_cast<bool *>(dst_data_), task_id, context_->thread_num_);
|
|
|
|
|
} else {
|
|
|
|
|
ret = int_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const int *>(src_data_),
|
|
|
|
|
static_cast<int *>(dst_data_), task_id, context_->thread_num_);
|
|
|
|
@ -117,6 +125,8 @@ int ReduceImpl(void *cdata, int task_id) {
|
|
|
|
|
int ReduceCPUKernel::Run() {
|
|
|
|
|
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
|
|
|
|
|
data_type_ = kDataTypeFloat;
|
|
|
|
|
} else if (in_tensors().at(0)->data_type() == kNumberTypeBool) {
|
|
|
|
|
data_type_ = KDataTypeBool;
|
|
|
|
|
} else {
|
|
|
|
|
data_type_ = kDataTypeInt;
|
|
|
|
|
}
|
|
|
|
@ -202,6 +212,8 @@ int ReduceCPUKernel::MallocTmpBuffer() {
|
|
|
|
|
void *buffer = nullptr;
|
|
|
|
|
if (data_type_ == kDataTypeFloat) {
|
|
|
|
|
buffer = context_->allocator->Malloc(size * sizeof(float));
|
|
|
|
|
} else if (data_type_ == KDataTypeBool) {
|
|
|
|
|
buffer = context_->allocator->Malloc(size * sizeof(bool));
|
|
|
|
|
} else {
|
|
|
|
|
buffer = context_->allocator->Malloc(size * sizeof(int));
|
|
|
|
|
}
|
|
|
|
|