|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "src/ops/strided_slice.h"
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
|
|
|
|
|
#ifndef PRIMITIVE_WRITEABLE
|
|
|
|
|
#include "src/ops/ops_register.h"
|
|
|
|
@ -172,7 +173,8 @@ Registry StridedSliceRegistry(schema::PrimitiveType_StridedSlice, StridedSliceCr
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr size_t kStridedSliceOutputNum = 1;
|
|
|
|
|
constexpr size_t kStridedSliceInputNum = 1;
|
|
|
|
|
constexpr size_t kStridedSliceMultiInputNum = 4;
|
|
|
|
|
constexpr size_t kStridedSliceMultiInputNumMin = 3;
|
|
|
|
|
constexpr size_t kStridedSliceMultiInputNumMax = 5;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void StridedSlice::ApplyNewAxisMask() {
|
|
|
|
@ -251,13 +253,91 @@ void StridedSlice::TransIndexToPositive() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int StridedSlice::HandleAxesInputExist(const std::vector<lite::Tensor *> &inputs) {
|
|
|
|
|
// when axes input exist:
|
|
|
|
|
// input order: data, begin, end, axes(opt), stride(opt)
|
|
|
|
|
auto input_tensor = inputs.at(0);
|
|
|
|
|
MS_ASSERT(input_tensor != nullptr);
|
|
|
|
|
auto begin_tensor = inputs.at(1);
|
|
|
|
|
MS_ASSERT(begin_tensor != nullptr);
|
|
|
|
|
int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
|
|
|
|
|
auto end_tensor = inputs.at(2);
|
|
|
|
|
MS_ASSERT(end_tensor != nullptr);
|
|
|
|
|
int *end_data = reinterpret_cast<int *>(end_tensor->MutableData());
|
|
|
|
|
if (begin_data == nullptr || end_data == nullptr) {
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
// when input contains axes, begins, ends, strides will be expand to the same length as input rank
|
|
|
|
|
ndim_ = static_cast<int>(input_tensor->shape().size());
|
|
|
|
|
int begin_ndim = begin_tensor->ElementsNum();
|
|
|
|
|
|
|
|
|
|
int *axes_data = nullptr;
|
|
|
|
|
auto axes_tensor = inputs.at(3);
|
|
|
|
|
if (axes_tensor->ElementsNum() != 0) {
|
|
|
|
|
MS_ASSERT(axes_tensor->ElementsNum() == begin_ndim);
|
|
|
|
|
axes_data = reinterpret_cast<int *>(axes_tensor->MutableData());
|
|
|
|
|
if (axes_data == nullptr) {
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int *stride_data = nullptr;
|
|
|
|
|
auto stride_tensor = inputs.at(4);
|
|
|
|
|
if (stride_tensor->ElementsNum() != 0) {
|
|
|
|
|
MS_ASSERT(stride_tensor->ElementsNum() == begin_ndim);
|
|
|
|
|
stride_data = reinterpret_cast<int *>(stride_tensor->MutableData());
|
|
|
|
|
if (stride_data == nullptr) {
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int> axes;
|
|
|
|
|
if (axes_data == nullptr) {
|
|
|
|
|
for (int i = 0; i < begin_ndim; ++i) {
|
|
|
|
|
axes[i] = i;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
axes.assign(axes_data, axes_data + begin_ndim);
|
|
|
|
|
for (int i = 0; i < begin_ndim; ++i) {
|
|
|
|
|
if (axes[i] < 0) {
|
|
|
|
|
axes[i] += ndim_;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
in_shape_.assign(ndim_, 0);
|
|
|
|
|
begins_.assign(ndim_, 0);
|
|
|
|
|
ends_.assign(ndim_, 0);
|
|
|
|
|
strides_.assign(ndim_, 0);
|
|
|
|
|
auto input_shape = input_tensor->shape();
|
|
|
|
|
for (int i = 0; i < ndim_; ++i) {
|
|
|
|
|
in_shape_[i] = input_shape.at(i);
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < ndim_; ++i) {
|
|
|
|
|
auto axes_it = std::find(axes.begin(), axes.end(), i);
|
|
|
|
|
if (axes_it != axes.end()) {
|
|
|
|
|
auto axis = axes_it - axes.begin();
|
|
|
|
|
// begins or ends exceed limit will be set to limit
|
|
|
|
|
begins_[i] = std::max(std::min(begin_data[axis], input_shape[i] - 1), -input_shape[i]);
|
|
|
|
|
ends_[i] = std::max(std::min(end_data[axis], input_shape[i]), -input_shape[i] - 1);
|
|
|
|
|
strides_[i] = stride_data[axis];
|
|
|
|
|
} else {
|
|
|
|
|
begins_[i] = 0;
|
|
|
|
|
ends_[i] = input_shape[i];
|
|
|
|
|
strides_[i] = 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs) {
|
|
|
|
|
MS_ASSERT(this->primitive_ != nullptr);
|
|
|
|
|
if (outputs.size() != kStridedSliceOutputNum) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
if (inputs.size() != kStridedSliceInputNum && inputs.size() != kStridedSliceMultiInputNum) {
|
|
|
|
|
if (inputs.size() != kStridedSliceInputNum &&
|
|
|
|
|
!(inputs.size() <= kStridedSliceMultiInputNumMax && inputs.size() >= kStridedSliceMultiInputNumMin)) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
@ -268,6 +348,10 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
auto input_shape = input->shape();
|
|
|
|
|
auto inferflag = GetInferFlag();
|
|
|
|
|
|
|
|
|
|
in_shape_.clear();
|
|
|
|
|
begins_.clear();
|
|
|
|
|
ends_.clear();
|
|
|
|
|
strides_.clear();
|
|
|
|
|
if (inputs.size() == kStridedSliceInputNum) {
|
|
|
|
|
ndim_ = static_cast<int>(GetBegin().size());
|
|
|
|
|
|
|
|
|
@ -279,7 +363,9 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
ends_.emplace_back((GetEnd())[i]);
|
|
|
|
|
strides_.emplace_back((GetStride())[i]);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
}
|
|
|
|
|
if (inputs.size() == 4) {
|
|
|
|
|
// input order: input, begins, ends, strides.
|
|
|
|
|
auto begin_tensor = inputs.at(1);
|
|
|
|
|
int *begin_data = reinterpret_cast<int *>(begin_tensor->MutableData());
|
|
|
|
|
auto end_tensor = inputs.at(2);
|
|
|
|
@ -299,6 +385,13 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
strides_.emplace_back(stride_data[i]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (inputs.size() == 5) {
|
|
|
|
|
// input order: input, begins, end, axes, strides
|
|
|
|
|
auto ret = HandleAxesInputExist(inputs);
|
|
|
|
|
if (ret != RET_OK) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// set all mask to original input shape
|
|
|
|
|
begins_mask_.resize(ndim_);
|
|
|
|
@ -333,7 +426,12 @@ int StridedSlice::InferShape(std::vector<lite::Tensor *> inputs, std::vector<lit
|
|
|
|
|
if (i < ndim_ && new_axis_mask_.at(i)) {
|
|
|
|
|
output_shape.at(i) = 1;
|
|
|
|
|
} else {
|
|
|
|
|
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
|
|
|
|
|
if (strides_.at(i) == 0) {
|
|
|
|
|
MS_LOG(ERROR) << "strides should not be 0.";
|
|
|
|
|
return RET_INFER_ERR;
|
|
|
|
|
}
|
|
|
|
|
output_shape.at(i) =
|
|
|
|
|
(ends_.at(i) - begins_.at(i) + strides_.at(i) + (strides_.at(i) < 0 ? 1 : -1)) / strides_.at(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|