|
|
|
@@ -223,6 +223,8 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector<ConstGeTensorPtr
|
|
|
|
|
|
|
|
|
|
vector<int64_t> orig_begin_vec, orig_end_vec, orig_stride_vec;
|
|
|
|
|
GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec);
|
|
|
|
|
// calculate begin_mask & end_mask by ellipsis_mask
|
|
|
|
|
ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec);
|
|
|
|
|
auto begin_dim_num = orig_begin_vec.size();
|
|
|
|
|
auto min_dim = x_dims_num > begin_dim_num ? begin_dim_num : x_dims_num;
|
|
|
|
|
for (size_t i = 0; i < x_dims.size(); ++i) {
|
|
|
|
@@ -281,6 +283,38 @@ void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_ten
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num,
|
|
|
|
|
const vector<int64_t> &x_dims, vector<int64_t> &orig_begin_vec,
|
|
|
|
|
vector<int64_t> &orig_end_vec, vector<int64_t> &orig_stride_vec) {
|
|
|
|
|
|
|
|
|
|
if (attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) != 0) {
|
|
|
|
|
auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK);
|
|
|
|
|
auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK);
|
|
|
|
|
if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) {
|
|
|
|
|
begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() -1));
|
|
|
|
|
attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask;
|
|
|
|
|
}
|
|
|
|
|
if (end_mask != 0 && x_dims_num != orig_end_vec.size()) {
|
|
|
|
|
end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() -1));
|
|
|
|
|
attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask;
|
|
|
|
|
}
|
|
|
|
|
for (auto i = 0; i < x_dims_num; ++i) {
|
|
|
|
|
bool ellipsis_mask_flag = attr_value_map_.at(STRIDE_SLICE_ATTR_ELLIPSIS_MASK) & (kMaskBitLeftUnit << i);
|
|
|
|
|
if (ellipsis_mask_flag) {
|
|
|
|
|
auto ellipsis_dim = i;
|
|
|
|
|
orig_begin_vec[i] = 0;
|
|
|
|
|
orig_end_vec[i] = x_dims.at(i);
|
|
|
|
|
orig_stride_vec[i] = 1;
|
|
|
|
|
if (auto j = 0; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) {
|
|
|
|
|
orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0);
|
|
|
|
|
orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j));
|
|
|
|
|
orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &end_i, int64_t &dim_i) const {
|
|
|
|
|
auto i_temp = static_cast<uint32_t>(i);
|
|
|
|
|
bool begin_mask_flag = (attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) & (kMaskBitLeftUnit << i_temp));
|
|
|
|
@@ -302,10 +336,6 @@ Status StridedSliceKernel::MaskCal(const size_t i, int64_t &begin_i, int64_t &en
|
|
|
|
|
} else {
|
|
|
|
|
end_i = (end_i < 0 ? (dim_i + end_i) : end_i);
|
|
|
|
|
}
|
|
|
|
|
if (ellipsis_mask_flag) {
|
|
|
|
|
begin_i = 0;
|
|
|
|
|
end_i = dim_i;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
@@ -316,9 +346,11 @@ Status StridedSliceKernel::StrideCal(const int64_t x_dims_i, int64_t &begin_i, i
|
|
|
|
|
stride_i = kDefaultStrideSize;
|
|
|
|
|
} else if (stride_i < 0) {
|
|
|
|
|
stride_i = -stride_i;
|
|
|
|
|
if (begin_i < 0 && end_i < 0) {
|
|
|
|
|
begin_i = x_dims_i - begin_i - 1;
|
|
|
|
|
end_i = x_dims_i - end_i - 1;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (end_i > x_dims_i) {
|
|
|
|
|
end_i = x_dims_i;
|
|
|
|
|