!9220 [MS][LITE]add reduce_all mode

From: @YeFeng_24
Reviewed-by: @zhanghaibo5,@hangangqiang,@hangangqiang
Signed-off-by:
pull/9220/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit af1022a587

@ -144,6 +144,29 @@ int IntReduceMin(int outer_size, int inner_size, int axis_size, const int *src_d
} }
return NNACL_OK; return NNACL_OK;
} }
int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid,
int thread_num) {
if (src_data == NULL || dst_data == NULL) {
return NNACL_NULL_PTR;
}
int i, j, k;
for (j = tid; j < outer_size; j += thread_num) {
const bool *outer_src = src_data + j * axis_size * inner_size;
bool *outer_dst = dst_data + j * inner_size;
for (k = 0; k < inner_size; k++) {
const bool *inner_src = outer_src + k;
bool *inner_dst = outer_dst + k;
bool tmp = true;
for (i = 0; i < axis_size; i++) {
tmp = tmp && inner_src[i * inner_size];
}
*inner_dst = tmp;
}
}
return NNACL_OK;
}
int ReduceProd(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, int ReduceProd(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num) { int thread_num) {
if (src_data == NULL || dst_data == NULL) { if (src_data == NULL || dst_data == NULL) {

@ -38,6 +38,8 @@ int IntReduceProd(int outer_size, int inner_size, int axis_size, const int *src_
int thread_num); int thread_num);
int ReduceSumSquare(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid, int ReduceSumSquare(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
int thread_num); int thread_num);
int ReduceAll(int outer_size, int inner_size, int axis_size, const bool *src_data, bool *dst_data, int tid,
int thread_num);
#ifdef ENABLE_NNACL_INFER_SHAPE #ifdef ENABLE_NNACL_INFER_SHAPE
int ReduceInferShape(int **in_shape, size_t *dim_size, int *out_shape, int *in_format, int *out_format, int ReduceInferShape(int **in_shape, size_t *dim_size, int *out_shape, int *in_format, int *out_format,

@ -63,6 +63,7 @@ typedef enum LiteDataType {
kDataTypeFloat, kDataTypeFloat,
kDataTypeInt, kDataTypeInt,
kDataTypeInt8, kDataTypeInt8,
KDataTypeBool,
} LiteDataType; } LiteDataType;
typedef enum DataOrder { typedef enum DataOrder {

@ -765,7 +765,8 @@ enum ReduceMode : byte {
ReduceProd = 3, ReduceProd = 3,
ReduceSum = 4, ReduceSum = 4,
ReduceSumSquare = 5, ReduceSumSquare = 5,
ReduceASum = 6 ReduceASum = 6,
ReduceAll = 7
} }
table Reduce { table Reduce {

@ -67,6 +67,11 @@ int Reduce::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inp
attr->mode = schema::ReduceMode_ReduceProd; attr->mode = schema::ReduceMode_ReduceProd;
} else if (prim.name() == "ReduceSumSquare") { } else if (prim.name() == "ReduceSumSquare") {
attr->mode = schema::ReduceMode_ReduceSumSquare; attr->mode = schema::ReduceMode_ReduceSumSquare;
} else if (prim.name() == "ReduceAll") {
attr->mode = schema::ReduceMode_ReduceAll;
} else {
MS_LOG(ERROR) << "Not supported reduce mode: " << prim.name();
return RET_ERROR;
} }
attr->keepDims = GetValue<bool>(prim.GetAttr("keep_dims")); attr->keepDims = GetValue<bool>(prim.GetAttr("keep_dims"));

@ -31,6 +31,7 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Mean; using mindspore::schema::PrimitiveType_Mean;
using mindspore::schema::PrimitiveType_Reduce; using mindspore::schema::PrimitiveType_Reduce;
using mindspore::schema::ReduceMode; using mindspore::schema::ReduceMode;
using mindspore::schema::ReduceMode_ReduceAll;
using mindspore::schema::ReduceMode_ReduceASum; using mindspore::schema::ReduceMode_ReduceASum;
using mindspore::schema::ReduceMode_ReduceMax; using mindspore::schema::ReduceMode_ReduceMax;
using mindspore::schema::ReduceMode_ReduceMean; using mindspore::schema::ReduceMode_ReduceMean;
@ -78,6 +79,10 @@ int ReduceCPUKernel::Init() {
reducer_ = ReduceSum; reducer_ = ReduceSum;
break; break;
} }
case static_cast<int>(ReduceMode_ReduceAll): {
bool_reducer_ = ReduceAll;
break;
}
default: default:
MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_; MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_;
return RET_ERROR; return RET_ERROR;
@ -96,6 +101,9 @@ int ReduceCPUKernel::CallReduceUnit(int task_id) {
if (data_type_ == kDataTypeFloat) { if (data_type_ == kDataTypeFloat) {
ret = reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_), ret = reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_),
static_cast<float *>(dst_data_), task_id, context_->thread_num_); static_cast<float *>(dst_data_), task_id, context_->thread_num_);
} else if (data_type_ == KDataTypeBool) {
ret = bool_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const bool *>(src_data_),
static_cast<bool *>(dst_data_), task_id, context_->thread_num_);
} else { } else {
ret = int_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const int *>(src_data_), ret = int_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const int *>(src_data_),
static_cast<int *>(dst_data_), task_id, context_->thread_num_); static_cast<int *>(dst_data_), task_id, context_->thread_num_);
@ -117,6 +125,8 @@ int ReduceImpl(void *cdata, int task_id) {
int ReduceCPUKernel::Run() { int ReduceCPUKernel::Run() {
if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) { if (in_tensors().at(0)->data_type() == kNumberTypeFloat32) {
data_type_ = kDataTypeFloat; data_type_ = kDataTypeFloat;
} else if (in_tensors().at(0)->data_type() == kNumberTypeBool) {
data_type_ = KDataTypeBool;
} else { } else {
data_type_ = kDataTypeInt; data_type_ = kDataTypeInt;
} }
@ -202,6 +212,8 @@ int ReduceCPUKernel::MallocTmpBuffer() {
void *buffer = nullptr; void *buffer = nullptr;
if (data_type_ == kDataTypeFloat) { if (data_type_ == kDataTypeFloat) {
buffer = context_->allocator->Malloc(size * sizeof(float)); buffer = context_->allocator->Malloc(size * sizeof(float));
} else if (data_type_ == KDataTypeBool) {
buffer = context_->allocator->Malloc(size * sizeof(bool));
} else { } else {
buffer = context_->allocator->Malloc(size * sizeof(int)); buffer = context_->allocator->Malloc(size * sizeof(int));
} }

@ -31,6 +31,8 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
float *dst_data, const int tid, const int thread_num); float *dst_data, const int tid, const int thread_num);
typedef int (*IntReducer)(const int outer_size, const int inner_size, const int axis_size, const int *src_data, typedef int (*IntReducer)(const int outer_size, const int inner_size, const int axis_size, const int *src_data,
int *dst_data, const int tid, const int thread_num); int *dst_data, const int tid, const int thread_num);
typedef int (*BoolReducer)(const int outer_size, const int inner_size, const int axis_size, const bool *src_data,
bool *dst_data, const int tid, const int thread_num);
public: public:
ReduceCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs, ReduceCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
@ -54,6 +56,7 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
private: private:
ReduceParameter *reduce_param_; ReduceParameter *reduce_param_;
Reducer reducer_ = nullptr; Reducer reducer_ = nullptr;
BoolReducer bool_reducer_ = nullptr;
IntReducer int_reducer_ = nullptr; IntReducer int_reducer_ = nullptr;
std::vector<void *> data_buffers_; std::vector<void *> data_buffers_;
LiteDataType data_type_; LiteDataType data_type_;

@ -52,6 +52,8 @@ STATUS TFReduceParser::Parse(const tensorflow::NodeDef &tf_op,
attr->mode = schema::ReduceMode_ReduceMean; attr->mode = schema::ReduceMode_ReduceMean;
} else if (tf_op.op() == "Prod") { } else if (tf_op.op() == "Prod") {
attr->mode = schema::ReduceMode_ReduceProd; attr->mode = schema::ReduceMode_ReduceProd;
} else if (tf_op.op() == "All") {
attr->mode = schema::ReduceMode_ReduceAll;
} else { } else {
MS_LOG(ERROR) << "unsupported reduce mode: " << tf_op.op(); MS_LOG(ERROR) << "unsupported reduce mode: " << tf_op.op();
return RET_ERROR; return RET_ERROR;
@ -106,5 +108,6 @@ TFNodeRegistrar g_tfMaxParser("Max", new TFReduceParser());
TFNodeRegistrar g_tfMinParser("Min", new TFReduceParser()); TFNodeRegistrar g_tfMinParser("Min", new TFReduceParser());
TFNodeRegistrar g_tfMeanParser("Mean", new TFReduceParser()); TFNodeRegistrar g_tfMeanParser("Mean", new TFReduceParser());
TFNodeRegistrar g_tfProdParser("Prod", new TFReduceParser()); TFNodeRegistrar g_tfProdParser("Prod", new TFReduceParser());
TFNodeRegistrar g_tfAllParser("All", new TFReduceParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

Loading…
Cancel
Save