|
|
|
@ -340,6 +340,7 @@ int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// note: begin, end, stride length are equal, but may less than rank of input
|
|
|
|
|
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
|
|
|
|
MS_ASSERT(this->primitive_ != nullptr);
|
|
|
|
|
if (outputs.size() != kStridedSliceOutputNum) {
|
|
|
|
@ -359,6 +360,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
auto inferflag = infer_flag();
|
|
|
|
|
|
|
|
|
|
in_shape_.clear();
|
|
|
|
|
if (inferflag) {
|
|
|
|
|
in_shape_.assign(input_shape.begin(), input_shape.end());
|
|
|
|
|
}
|
|
|
|
|
begins_.clear();
|
|
|
|
|
ends_.clear();
|
|
|
|
|
strides_.clear();
|
|
|
|
@ -366,9 +370,6 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
ndim_ = static_cast<int>(GetBegin().size());
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
if (inferflag) {
|
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
}
|
|
|
|
|
begins_.emplace_back((GetBegin()).at(i));
|
|
|
|
|
ends_.emplace_back((GetEnd()).at(i));
|
|
|
|
|
strides_.emplace_back((GetStride()).at(i));
|
|
|
|
@ -391,9 +392,6 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
}
|
|
|
|
|
ndim_ = begin_tensor->ElementsNum();
|
|
|
|
|
for (int i = 0; i < ndim_; ++i) {
|
|
|
|
|
if (inferflag) {
|
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
}
|
|
|
|
|
begins_.emplace_back(begin_data[i]);
|
|
|
|
|
ends_.emplace_back(end_data[i]);
|
|
|
|
|
strides_.emplace_back(stride_data[i]);
|
|
|
|
@ -431,22 +429,16 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
if (!inferflag) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> output_shape;
|
|
|
|
|
output_shape.clear();
|
|
|
|
|
output_shape.resize(in_shape_.size());
|
|
|
|
|
std::vector<int> output_shape(in_shape_);
|
|
|
|
|
|
|
|
|
|
TransIndexToPositive();
|
|
|
|
|
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) {
|
|
|
|
|
if (i < ndim_ && new_axis_mask_.at(i)) {
|
|
|
|
|
output_shape.at(i) = 1;
|
|
|
|
|
} else {
|
|
|
|
|
if (strides_.at(i) == 0) {
|
|
|
|
|
MS_LOG(ERROR) << "strides should not be 0.";
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
output_shape.at(i) =
|
|
|
|
|
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
if (strides_.at(i) == 0) {
|
|
|
|
|
MS_LOG(ERROR) << "strides should not be 0.";
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
output_shape.at(i) =
|
|
|
|
|
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
output_shape = ApplyShrinkMask(output_shape);
|
|
|
|
|