From 75b6f82a296fb9a5e920e592882ed8c7047dc632 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Wed, 16 Sep 2020 15:47:55 +0800 Subject: [PATCH] fix strided slice infer shape error when neg stride --- mindspore/lite/src/ops/strided_slice.cc | 19 +++++++++++++++---- mindspore/lite/src/ops/strided_slice.h | 1 + 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 94b3e8f3d4..6cb0356c9b 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -226,6 +226,17 @@ void StridedSlice::ApplyEndMask() { } } +void StridedSlice::TransIndexToPositive() { + for (int i = 0; i < static_cast(begins_.size()); ++i) { + if (begins_.at(i) < 0) { + begins_.at(i) += in_shape_.at(i); + } + if (ends_.at(i) < 0) { + ends_.at(i) += in_shape_.at(i); + } + } +} + int StridedSlice::InferShape(std::vector inputs, std::vector outputs) { MS_ASSERT(this->primitive_ != nullptr); if (outputs.size() != kStridedSliceOutputNum) { @@ -266,7 +277,7 @@ int StridedSlice::InferShape(std::vector inputs, std::vectorElementsNum(); - for (int i=0; i< ndim_; ++i) { + for (int i = 0; i < ndim_; ++i) { in_shape_.emplace_back(input_shape.at(i)); begins_.emplace_back(begin_data[i]); ends_.emplace_back(end_data[i]); @@ -297,13 +308,13 @@ int StridedSlice::InferShape(std::vector inputs, std::vector(in_shape_.size()); i++) { if (i < ndim_ && new_axis_mask_.at(i)) { output_shape.at(i) = 1; - } else if (ends_.at(i) > 0) { - output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); } else { - output_shape.at(i) = (input_shape.at(i) + ends_.at(i) - begins_.at(i)) % input_shape.at(i) / strides_.at(i); + output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); } } diff --git a/mindspore/lite/src/ops/strided_slice.h b/mindspore/lite/src/ops/strided_slice.h index c135aab665..cd9cdb2ee6 100644 --- a/mindspore/lite/src/ops/strided_slice.h +++ b/mindspore/lite/src/ops/strided_slice.h @@ -80,6 +80,7 @@ class StridedSlice : public PrimitiveC { std::vector ellipsis_mask_; std::vector new_axis_mask_; std::vector shrink_axis_mask_; + void TransIndexToPositive(); }; } // namespace lite } // namespace mindspore