!5850 [MSLITE] strided slice support neg strides

Merge pull request !5850 from zhaozhenlong/lite/issue/strided_slice
pull/5850/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0aad6800e7

@ -22,14 +22,22 @@ void PadStridedSliceParameterTo4D(StridedSliceParameter *param) {
int32_t ends[DIMENSION_4D];
int32_t strides[DIMENSION_4D];
int32_t input_shape[DIMENSION_4D];
for (int32_t i = 0; i < param->num_axes_; ++i) {
int32_t i;
for (i = 0; i < param->num_axes_; ++i) {
begins[i] = param->begins_[i];
ends[i] = param->ends_[i];
ends[i] = MSMIN(param->ends_[i], param->in_shape_[i]);
strides[i] = param->strides_[i];
input_shape[i] = param->in_shape_[i];
}
int32_t real_index = param->num_axes_ - 1;
for (int32_t i = DIMENSION_4D - 1; i >= 0; --i) {
for (i = param->num_axes_; i < param->in_shape_length_; ++i) {
input_shape[i] = param->in_shape_[i];
begins[i] = 0;
ends[i] = param->in_shape_[i];
strides[i] = 1;
}
int32_t real_index = param->in_shape_length_ - 1;
for (i = DIMENSION_4D - 1; i >= 0; --i) {
if (real_index >= 0) {
param->begins_[i] = begins[real_index];
param->ends_[i] = ends[real_index];
@ -43,8 +51,23 @@ void PadStridedSliceParameterTo4D(StridedSliceParameter *param) {
}
}
param->num_axes_ = DIMENSION_4D;
param->in_shape_length_ = DIMENSION_4D;
}
void ChangeNegToPositive(StridedSliceParameter *param) {
int i;
for (i = 0; i < DIMENSION_4D; ++i) {
if (param->begins_[i] < 0) {
param->begins_[i] += param->in_shape_[i];
}
if (param->ends_[i] < 0) {
param->ends_[i] += param->in_shape_[i];
}
}
}
inline bool LoopContinue(int stride, int i, int end) { return stride > 0 ? i < end : i > end; }
int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *param) {
if (in_data == NULL || out_data == NULL || param == NULL) {
return NNACL_NULL_PTR;
@ -61,16 +84,18 @@ int DoStridedSlice(const void *in_data, void *out_data, StridedSliceParameter *p
if (param->num_axes_ < DIMENSION_4D) {
PadStridedSliceParameterTo4D(param);
}
ChangeNegToPositive(param);
size_t dim_offset[DIMENSION_4D - 1];
dim_offset[2] = in_shape[3];
dim_offset[1] = dim_offset[2] * in_shape[2];
dim_offset[0] = dim_offset[1] * in_shape[1];
size_t out_offset = 0;
for (int32_t dim0 = begins[0]; dim0 < ends[0]; dim0 += strides[0]) {
for (int32_t dim1 = begins[1]; dim1 < ends[1]; dim1 += strides[1]) {
for (int32_t dim2 = begins[2]; dim2 < ends[2]; dim2 += strides[2]) {
for (int32_t dim3 = begins[3]; dim3 < ends[3]; dim3 += strides[3]) {
int32_t dim0, dim1, dim2, dim3;
for (dim0 = begins[0]; LoopContinue(strides[0], dim0, ends[0]); dim0 += strides[0]) {
for (dim1 = begins[1]; LoopContinue(strides[1], dim1, ends[1]); dim1 += strides[1]) {
for (dim2 = begins[2]; LoopContinue(strides[2], dim2, ends[2]); dim2 += strides[2]) {
for (dim3 = begins[3]; LoopContinue(strides[3], dim3, ends[3]); dim3 += strides[3]) {
int32_t in_offset = dim0 * dim_offset[0] + dim1 * dim_offset[1] + dim2 * dim_offset[2] + dim3;
if (param->data_type == kDataTypeFloat) {
*((float *)out_data + out_offset) = *((float *)in_data + in_offset);

@ -25,6 +25,7 @@ typedef struct StridedSliceParameter {
int strides_[8];
int isScale;
int num_axes_;
int in_shape_length_;
int in_shape_[8];
LiteDataType data_type;
} StridedSliceParameter;

@ -44,6 +44,11 @@ int StridedSliceCPUKernel::ReSize() {
MS_ASSERT(input);
MS_ASSERT(parameter);
parameter->data_type = input->data_type() == kNumberTypeInt8 ? kDataTypeInt8 : kDataTypeFloat;
auto input_shape = input->shape();
for (size_t i = 0; i < input_shape.size(); ++i) {
parameter->in_shape_[i] = input_shape[i];
}
parameter->in_shape_length_ = static_cast<int>(input_shape.size());
return RET_OK;
}

Loading…
Cancel
Save