modified: ge/host_kernels/strided_slice_kernel.cc

pull/339/head
zhaoxinxin 4 years ago
parent c7ee494caf
commit 2a5548b192

@ -291,11 +291,11 @@ void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num,
auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK); auto end_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK);
auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK); auto begin_mask = attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK);
if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) { if (begin_mask != 0 && x_dims_num != orig_begin_vec.size()) {
begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() -1)); begin_mask *= begin_mask * (kMaskBitLeftUnit << (x_dims_num - orig_begin_vec.size() - 1));
attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask; attr_value_map_.at(STRIDE_SLICE_ATTR_BEGIN_MASK) = begin_mask;
} }
if (end_mask != 0 && x_dims_num != orig_end_vec.size()) { if (end_mask != 0 && x_dims_num != orig_end_vec.size()) {
end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() -1)); end_mask *= end_mask * (kMaskBitLeftUnit << (x_dims_num - orig_end_vec.size() - 1));
attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask; attr_value_map_.at(STRIDE_SLICE_ATTR_END_MASK) = end_mask;
} }
for (auto i = 0; i < x_dims_num; ++i) { for (auto i = 0; i < x_dims_num; ++i) {
@ -306,7 +306,7 @@ void StridedSliceKernel::ExpandStrideWithEllipsisMask(const size_t x_dims_num,
orig_end_vec[i] = x_dims.at(i); orig_end_vec[i] = x_dims.at(i);
orig_stride_vec[i] = 1; orig_stride_vec[i] = 1;
if (orig_begin_vec.size() < x_dims_num) { if (orig_begin_vec.size() < x_dims_num) {
for (auto j = 0; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) { for (auto j = 1; j < (x_dims_num - orig_begin_vec.size() + 1); ++j) {
orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0); orig_begin_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 0);
orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j)); orig_end_vec.insert((orig_end_vec.begin() + ellipsis_dim + j), x_dims.at(ellipsis_dim +j));
orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1); orig_stride_vec.insert((orig_begin_vec.begin() + ellipsis_dim + j), 1);

Loading…
Cancel
Save