!9992 fix strided slice shrink axis mask

From: @zhaozhenlong
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
pull/9992/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 99fca84f94

@ -340,6 +340,7 @@ int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs
return RET_OK;
}
// note: begin, end, stride length are equal, but may less than rank of input
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);
if (outputs.size() != kStridedSliceOutputNum) {
@ -359,6 +360,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
auto inferflag = infer_flag();
in_shape_.clear();
if (inferflag) {
in_shape_.assign(input_shape.begin(), input_shape.end());
}
begins_.clear();
ends_.clear();
strides_.clear();
@ -366,9 +370,6 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
ndim_ = static_cast<int>(GetBegin().size());
for (int i = 0; i < ndim_; i++) {
if (inferflag) {
in_shape_.emplace_back(input_shape.at(i));
}
begins_.emplace_back((GetBegin()).at(i));
ends_.emplace_back((GetEnd()).at(i));
strides_.emplace_back((GetStride()).at(i));
@ -391,9 +392,6 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
}
ndim_ = begin_tensor->ElementsNum();
for (int i = 0; i < ndim_; ++i) {
if (inferflag) {
in_shape_.emplace_back(input_shape.at(i));
}
begins_.emplace_back(begin_data[i]);
ends_.emplace_back(end_data[i]);
strides_.emplace_back(stride_data[i]);
@ -431,22 +429,16 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
if (!inferflag) {
return RET_OK;
}
std::vector<int> output_shape;
output_shape.clear();
output_shape.resize(in_shape_.size());
std::vector<int> output_shape(in_shape_);
TransIndexToPositive();
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) {
if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1;
} else {
if (strides_.at(i) == 0) {
MS_LOG(ERROR) << "strides should not be 0.";
return RET_INFER_ERR;
}
output_shape.at(i) =
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
for (int i = 0; i < ndim_; i++) {
if (strides_.at(i) == 0) {
MS_LOG(ERROR) << "strides should not be 0.";
return RET_INFER_ERR;
}
output_shape.at(i) =
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
}
output_shape = ApplyShrinkMask(output_shape);

Loading…
Cancel
Save