|
|
|
@ -89,7 +89,7 @@ constexpr int kStridedSliceInputNum = 1;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void StridedSlice::ApplyNewAxisMask() {
|
|
|
|
|
for (int i = 0; i < new_axis_mask_.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < new_axis_mask_.size(); i++) {
|
|
|
|
|
if (new_axis_mask_.at(i)) {
|
|
|
|
|
ndim_ += 1;
|
|
|
|
|
in_shape_.insert(in_shape_.begin() + i, 1);
|
|
|
|
@ -112,7 +112,7 @@ void StridedSlice::ApplyNewAxisMask() {
|
|
|
|
|
std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
|
|
|
|
|
auto old_out_shape = out_shape;
|
|
|
|
|
out_shape.clear();
|
|
|
|
|
for (int i = 0; i < shrink_axis_mask_.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < shrink_axis_mask_.size(); i++) {
|
|
|
|
|
if (shrink_axis_mask_.at(i)) {
|
|
|
|
|
ends_.at(i) = begins_.at(i) + 1;
|
|
|
|
|
strides_.at(i) = 1;
|
|
|
|
@ -120,7 +120,7 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
|
|
|
|
|
out_shape.emplace_back(old_out_shape.at(i));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (int i = shrink_axis_mask_.size(); i < old_out_shape.size(); i++) {
|
|
|
|
|
for (size_t i = shrink_axis_mask_.size(); i < old_out_shape.size(); i++) {
|
|
|
|
|
out_shape.emplace_back(old_out_shape.at(i));
|
|
|
|
|
}
|
|
|
|
|
return out_shape;
|
|
|
|
@ -128,7 +128,7 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
|
|
|
|
|
|
|
|
|
|
/*only one bit will be used if multiple bits are true.*/
|
|
|
|
|
void StridedSlice::ApplyEllipsisMask() {
|
|
|
|
|
for (int i = 0; i < ellipsis_mask_.size(); i++) {
|
|
|
|
|
for (size_t i = 0; i < ellipsis_mask_.size(); i++) {
|
|
|
|
|
if (ellipsis_mask_.at(i)) {
|
|
|
|
|
begins_.at(i) = 0;
|
|
|
|
|
ends_.at(i) = in_shape_.at(i);
|
|
|
|
@ -204,7 +204,7 @@ int StridedSlice::InferShape(std::vector<lite::tensor::Tensor *> inputs, std::ve
|
|
|
|
|
|
|
|
|
|
output_shape.clear();
|
|
|
|
|
output_shape.resize(in_shape_.size());
|
|
|
|
|
for (int i = 0; i < in_shape_.size(); i++) {
|
|
|
|
|
for (int i = 0; i < static_cast<int>(in_shape_.size()); i++) {
|
|
|
|
|
if (i < ndim_ && new_axis_mask_.at(i)) {
|
|
|
|
|
output_shape.at(i) = 1;
|
|
|
|
|
} else {
|
|
|
|
|