|
|
|
@ -30,14 +30,20 @@ constexpr int kStridedSliceInputNum = 1;
|
|
|
|
|
void StridedSlice::ApplyNewAxisMask() {
|
|
|
|
|
for (int i = 0; i < new_axis_mask_.size(); i++) {
|
|
|
|
|
if (new_axis_mask_.at(i)) {
|
|
|
|
|
updated_ndim_ += 1;
|
|
|
|
|
updated_in_shape_.insert(updated_in_shape_.begin() + i, 1);
|
|
|
|
|
updated_begins_.at(i) = 0;
|
|
|
|
|
updated_ends_.at(i) = 1;
|
|
|
|
|
updated_strides_.at(i) = 1;
|
|
|
|
|
updated_begins_.emplace_back(0);
|
|
|
|
|
updated_ends_.emplace_back(updated_in_shape_.at(updated_ndim_ - 1));
|
|
|
|
|
updated_strides_.emplace_back(1);
|
|
|
|
|
ndim_ += 1;
|
|
|
|
|
in_shape_.insert(in_shape_.begin() + i, 1);
|
|
|
|
|
begins_.at(i) = 0;
|
|
|
|
|
ends_.at(i) = 1;
|
|
|
|
|
strides_.at(i) = 1;
|
|
|
|
|
|
|
|
|
|
begins_.emplace_back(0);
|
|
|
|
|
ends_.emplace_back(in_shape_.at(ndim_ - 1));
|
|
|
|
|
strides_.emplace_back(1);
|
|
|
|
|
|
|
|
|
|
begins_mask_.at(i) = false;
|
|
|
|
|
ends_mask_.at(i) = false;
|
|
|
|
|
ellipsis_mask_.at(i) = false;
|
|
|
|
|
shrink_axis_mask_.at(i) = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -47,8 +53,8 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
|
|
|
|
|
out_shape.clear();
|
|
|
|
|
for (int i = 0; i < shrink_axis_mask_.size(); i++) {
|
|
|
|
|
if (shrink_axis_mask_.at(i)) {
|
|
|
|
|
updated_ends_.at(i) = updated_begins_.at(i) + 1;
|
|
|
|
|
updated_strides_.at(i) = 1;
|
|
|
|
|
ends_.at(i) = begins_.at(i) + 1;
|
|
|
|
|
strides_.at(i) = 1;
|
|
|
|
|
} else {
|
|
|
|
|
out_shape.emplace_back(old_out_shape.at(i));
|
|
|
|
|
}
|
|
|
|
@ -63,22 +69,26 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
|
|
|
|
|
void StridedSlice::ApplyEllipsisMask() {
|
|
|
|
|
for (int i = 0; i < ellipsis_mask_.size(); i++) {
|
|
|
|
|
if (ellipsis_mask_.at(i)) {
|
|
|
|
|
updated_begins_.at(i) = 0;
|
|
|
|
|
updated_ends_.at(i) = updated_in_shape_.at(i);
|
|
|
|
|
begins_.at(i) = 0;
|
|
|
|
|
ends_.at(i) = in_shape_.at(i);
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void StridedSlice::ApplyBeginMask() {
|
|
|
|
|
for (int i = 0; i < ori_ndim_; i++) {
|
|
|
|
|
updated_begins_.at(i) = 0;
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
if (begins_mask_.at(i)) {
|
|
|
|
|
begins_.at(i) = 0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void StridedSlice::ApplyEndMask() {
|
|
|
|
|
for (int i = 0; i < ori_ndim_; i++) {
|
|
|
|
|
updated_ends_.at(i) = 0;
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
if (ends_.at(i)) {
|
|
|
|
|
ends_.at(i) = in_shape_.at(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -88,7 +98,7 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
|
|
|
|
|
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
if (inputs.size() < kStridedSliceInputNum) {
|
|
|
|
|
if (inputs.size() != kStridedSliceInputNum) {
|
|
|
|
|
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
@ -97,28 +107,28 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
|
|
|
|
|
auto input_shape = input->shape();
|
|
|
|
|
std::vector<int> output_shape;
|
|
|
|
|
auto strided_slice_prim = this->primitive->value_as_StridedSlice();
|
|
|
|
|
updated_ndim_ = static_cast<int>(strided_slice_prim->begin()->size());
|
|
|
|
|
ori_ndim_ = updated_ndim_;
|
|
|
|
|
MS_ASSERT(updated_ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
|
|
|
|
|
MS_ASSERT(updated_ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
|
|
|
|
|
MS_ASSERT(updated_ndim_ == static_cast<int>(input_shape.size()));
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < updated_ndim_; i++) {
|
|
|
|
|
updated_in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
updated_begins_.emplace_back((*(strided_slice_prim->begin()))[i]);
|
|
|
|
|
updated_ends_.emplace_back((*(strided_slice_prim->end()))[i]);
|
|
|
|
|
updated_strides_.emplace_back((*(strided_slice_prim->stride()))[i]);
|
|
|
|
|
ndim_ = static_cast<int>(strided_slice_prim->begin()->size());
|
|
|
|
|
|
|
|
|
|
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
|
|
|
|
|
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
|
|
|
|
|
MS_ASSERT(ndim_ == static_cast<int>(input_shape.size()));
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
in_shape_.emplace_back(input_shape.at(i));
|
|
|
|
|
begins_.emplace_back((*(strided_slice_prim->begin()))[i]);
|
|
|
|
|
ends_.emplace_back((*(strided_slice_prim->end()))[i]);
|
|
|
|
|
strides_.emplace_back((*(strided_slice_prim->stride()))[i]);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// set all mask to original input shape
|
|
|
|
|
begins_mask_.resize(updated_ndim_);
|
|
|
|
|
ends_mask_.resize(updated_ndim_);
|
|
|
|
|
ellipsis_mask_.resize(updated_ndim_);
|
|
|
|
|
new_axis_mask_.resize(updated_ndim_);
|
|
|
|
|
shrink_axis_mask_.resize(updated_ndim_);
|
|
|
|
|
begins_mask_.resize(ndim_);
|
|
|
|
|
ends_mask_.resize(ndim_);
|
|
|
|
|
ellipsis_mask_.resize(ndim_);
|
|
|
|
|
new_axis_mask_.resize(ndim_);
|
|
|
|
|
shrink_axis_mask_.resize(ndim_);
|
|
|
|
|
|
|
|
|
|
// convert bit to vector
|
|
|
|
|
for (int i = 0; i < updated_ndim_; i++) {
|
|
|
|
|
for (int i = 0; i < ndim_; i++) {
|
|
|
|
|
begins_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->beginMask()) & (1 << i);
|
|
|
|
|
ends_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->endMask()) & (1 << i);
|
|
|
|
|
ellipsis_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->ellipsisMask()) & (1 << i);
|
|
|
|
@ -127,29 +137,17 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ApplyNewAxisMask();
|
|
|
|
|
ApplyNewAxisMask();
|
|
|
|
|
ApplyBeginMask();
|
|
|
|
|
ApplyEndMask();
|
|
|
|
|
ApplyEllipsisMask();
|
|
|
|
|
|
|
|
|
|
output_shape.resize(updated_in_shape_.size());
|
|
|
|
|
for (int i = 0; i < updated_in_shape_.size(); i++) {
|
|
|
|
|
if (i < ori_ndim_ && new_axis_mask_.at(i)) {
|
|
|
|
|
output_shape.clear();
|
|
|
|
|
output_shape.resize(in_shape_.size());
|
|
|
|
|
for (int i = 0; i < in_shape_.size(); i++) {
|
|
|
|
|
if (i < ndim_ && new_axis_mask_.at(i)) {
|
|
|
|
|
output_shape.at(i) = 1;
|
|
|
|
|
} else {
|
|
|
|
|
// begins and ends out of range handling
|
|
|
|
|
if (updated_begins_.at(i) >= updated_in_shape_.at(i) || updated_begins_.at(i) < -updated_in_shape_.at(i) ||
|
|
|
|
|
updated_ends_.at(i) < -updated_in_shape_.at(i) || updated_ends_.at(i) > updated_in_shape_.at(i)) {
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
updated_begins_.at(i) = updated_begins_.at(i) % updated_in_shape_.at(i);
|
|
|
|
|
updated_ends_.at(i) = updated_ends_.at(i) % updated_in_shape_.at(i);
|
|
|
|
|
|
|
|
|
|
if ((updated_ends_.at(i) <= updated_begins_.at(i) && updated_strides_.at(i) > 0) ||
|
|
|
|
|
(updated_ends_.at(i) >= updated_begins_.at(i) && updated_strides_.at(i) < 0)) {
|
|
|
|
|
output_shape.at(i) = 0;
|
|
|
|
|
} else {
|
|
|
|
|
output_shape.at(i) = 1 + (updated_ends_.at(i) - updated_begins_.at(i) - 1) / updated_strides_.at(i);
|
|
|
|
|
}
|
|
|
|
|
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|