diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index b034b618ef..493e12e8f5 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -37,8 +37,9 @@ #include "dataset/kernels/image/resize_bilinear_op.h" #include "dataset/kernels/image/resize_op.h" #include "dataset/kernels/image/uniform_aug_op.h" -#include "dataset/kernels/data/type_cast_op.h" #include "dataset/kernels/data/fill_op.h" +#include "dataset/kernels/data/slice_op.h" +#include "dataset/kernels/data/type_cast_op.h" #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/image_folder_op.h" #include "dataset/engine/datasetops/source/io_block.h" @@ -369,6 +370,37 @@ void bindTensorOps2(py::module *m) { *m, "FillOp", "Tensor operation to return tensor filled with same value as input fill value.") .def(py::init>()); + (void)py::class_>(*m, "SliceOp", "") + .def(py::init()) + .def(py::init([](const py::list &py_list) { + std::vector c_list; + for (auto l : py_list) { + if (!l.is_none()) { + c_list.push_back(py::reinterpret_borrow(l)); + } + } + return std::make_shared(c_list); + })) + .def(py::init([](const py::tuple &py_slice) { + if (py_slice.size() != 3) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); + } + Slice c_slice; + if (!py_slice[0].is_none() && !py_slice[1].is_none() && !py_slice[2].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1]), + py::reinterpret_borrow(py_slice[2])); + } else if (py_slice[0].is_none() && py_slice[2].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[1])); + } else if (!py_slice[0].is_none() && !py_slice[1].is_none()) { + c_slice = Slice(py::reinterpret_borrow(py_slice[0]), py::reinterpret_borrow(py_slice[1])); + } + + if (!c_slice.valid()) { + THROW_IF_ERROR(Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Wrong slice object")); + } + return std::make_shared(c_slice); + })); + (void)py::class_>( *m, "RandomRotationOp", "Tensor operation to apply RandomRotation." diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc index 539c7adabb..2b587a115a 100644 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ b/mindspore/ccsrc/dataset/core/tensor.cc @@ -916,6 +916,61 @@ Status Tensor::CopyLastDimAt(const std::shared_ptr &src, const std::vect CHECK_FAIL_RETURN_UNEXPECTED(memcpy_s(dst_addr, len, src_addr, len) == 0, "memcpy error"); return Status::OK(); } +Status Tensor::Slice(std::shared_ptr *out, const std::vector &indices) { + CHECK_FAIL_RETURN_UNEXPECTED(shape_.Rank() == 1, "Currently Slice work with rank 1 tensors only."); + CHECK_FAIL_RETURN_UNEXPECTED(!indices.empty(), "Indices are empty, generated tensor would be empty."); + if (type_.IsNumeric()) { + return SliceNumeric(out, indices); + } else { + return SliceString(out, indices); + } +} +Status Tensor::SliceNumeric(std::shared_ptr *out, const std::vector &indices) { + RETURN_IF_NOT_OK( + CreateTensor(out, TensorImpl::kFlexible, TensorShape({static_cast(indices.size())}), type_)); + (*out)->GetMutableBuffer(); + dsize_t out_index = 0; + dsize_t dim_length = shape_[0]; + dsize_t type_size = type_.SizeInBytes(); + dsize_t src_start = handleNeg(indices[0], dim_length); + uchar *dst_addr = (*out)->data_; + dsize_t count = 1; + + for (dsize_t i = 0; i < indices.size(); i++) { + dsize_t cur_index = handleNeg(indices[i], dim_length); + CHECK_FAIL_RETURN_UNEXPECTED( + cur_index >= 0 && cur_index < dim_length, + "Index " + std::to_string(indices[i]) + " is out of bounds [0," + std::to_string(dim_length) + ")"); + if (i < indices.size() - 1) { + dsize_t next_index = handleNeg(indices[i + 1], dim_length); + if (next_index == cur_index + 1) { + count++; + continue; + } + } + memcpy_s(dst_addr + out_index * type_size, (*out)->SizeInBytes(), data_ + src_start * type_size, count * type_size); + out_index += count; + if (i < indices.size() - 1) { + src_start = handleNeg(indices[i + 1], dim_length); // next index + } + count = 1; + } + return Status::OK(); +} +Status Tensor::SliceString(std::shared_ptr *out, const std::vector &indices) { + dsize_t dim_length = shape_[0]; + std::vector strings; + for (dsize_t index : indices) { + dsize_t cur_index = handleNeg(index, dim_length); + CHECK_FAIL_RETURN_UNEXPECTED( + cur_index >= 0 && cur_index < dim_length, + "Index " + std::to_string(index) + " is out of bounds [0," + std::to_string(dim_length) + ")"); + std::string_view sv; + GetItemAt(&sv, {cur_index}); + strings.emplace_back(sv); + } + return CreateTensor(out, strings); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/core/tensor.h b/mindspore/ccsrc/dataset/core/tensor.h index 032a9cafc9..2953b7df81 100644 --- a/mindspore/ccsrc/dataset/core/tensor.h +++ b/mindspore/ccsrc/dataset/core/tensor.h @@ -347,6 +347,22 @@ class Tensor { return ss.str(); } + // Handle negative indices. + static inline dsize_t handleNeg(dsize_t index, dsize_t length) { return (index < 0) ? (index + length) : index; } + + // Slice tensor bases on the given indicies. Copy the sliced data into out tensor. Only rank1 tensors are supported. + // Based on the type of tensor, SliceNumeric or SliceString will be called + // @param out Tensor + // @param indices vector of indices + // @return Status error code + Status Slice(std::shared_ptr *out, const std::vector &indices); + + // Slice numeric tensors. + Status SliceNumeric(std::shared_ptr *out, const std::vector &indices); + + // Slice string tensors + Status SliceString(std::shared_ptr *out, const std::vector &indices); + // Constructs numpy array from input tensor // @param data this data is the location of python data // @return Status code diff --git a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt index 8c03b300ee..53a1ea6151 100644 --- a/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/kernels/data/CMakeLists.txt @@ -5,4 +5,5 @@ add_library(kernels-data OBJECT one_hot_op.cc type_cast_op.cc to_float16_op.cc - fill_op.cc) + fill_op.cc + slice_op.cc) diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc b/mindspore/ccsrc/dataset/kernels/data/slice_op.cc new file mode 100644 index 0000000000..c90b76b9d9 --- /dev/null +++ b/mindspore/ccsrc/dataset/kernels/data/slice_op.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/kernels/data/slice_op.h" + +#include "dataset/core/tensor.h" +#include "dataset/kernels/data/data_utils.h" +#include "dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +Status SliceOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Rank() == 1, "SliceOp supports 1D Tensors only for now."); + + // if `all` flag is true, output is just the input. + if (all_) { + *output = input; + return Status::OK(); + } + + // if slice object was provided, indices should be empty. Generate indices from the slice object. + if (slice_.valid() && indices_.empty()) { + dsize_t len = input->shape()[0]; + indices_ = slice_.Indices(len); + return input->Slice(output, indices_); + } + + // if indices are not empty, slices should be invalid, use indices_ to slice + if (!indices_.empty() && !slice_.valid()) { + return input->Slice(output, indices_); + } + RETURN_STATUS_UNEXPECTED("The indexing parameters are invalid"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.h b/mindspore/ccsrc/dataset/kernels/data/slice_op.h new file mode 100644 index 0000000000..cb1318dcba --- /dev/null +++ b/mindspore/ccsrc/dataset/kernels/data/slice_op.h @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_KERNELS_DATA_SLICE_OP_H_ +#define DATASET_KERNELS_DATA_SLICE_OP_H_ + +#include +#include +#include +#include +#include + +#include "dataset/core/tensor.h" +#include "dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class Slice { + public: + Slice() : start_(0), stop_(0), step_(0) {} + Slice(dsize_t start, dsize_t stop, dsize_t step) : start_(start), stop_(stop), step_(step) {} + Slice(dsize_t start, dsize_t stop) : start_(start), stop_(stop), step_(1) {} + explicit Slice(dsize_t stop) : start_(0), stop_(stop), step_(1) {} + + std::vector Indices(dsize_t length) { + std::vector indices; + dsize_t index = std::min(Tensor::handleNeg(start_, length), length); + dsize_t end_index = std::min(Tensor::handleNeg(stop_, length), length); + if (step_ > 0) { + for (; index < end_index; index += step_) { + indices.push_back(index); + } + } else { + for (; index > end_index; index += step_) { + indices.push_back(index); + } + } + return indices; + } + + bool valid() { return !(start_ == 0 && stop_ == 0 && step_ == 0); } + + dsize_t start_; + dsize_t stop_; + dsize_t step_; +}; + +class SliceOp : public TensorOp { + public: + explicit SliceOp(std::vector indices) : indices_(std::move(indices)) {} + explicit SliceOp(Slice slice) : slice_(slice) {} + explicit SliceOp(bool all) : all_(all) {} + + ~SliceOp() override = default; + + void Print(std::ostream &out) const override { out << "SliceOp"; } + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + private: + // only on of the following will be valid + // given indices to slice the Tensor. Empty vector if invalid. + std::vector indices_; + // Slice object. All start, stop and step are 0 if invalid. + Slice slice_; + // Flag to read all indcies in the dim. + bool all_ = false; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_KERNELS_DATA_ONE_HOT_OP_H_ diff --git a/mindspore/dataset/transforms/c_transforms.py b/mindspore/dataset/transforms/c_transforms.py index 8f301f196e..fd9fb12d6a 100644 --- a/mindspore/dataset/transforms/c_transforms.py +++ b/mindspore/dataset/transforms/c_transforms.py @@ -17,7 +17,8 @@ This module c_transforms provides common operations, including OneHotOp and Type """ import numpy as np import mindspore._c_dataengine as cde -from .validators import check_num_classes, check_de_type, check_fill_value + +from .validators import check_num_classes, check_de_type, check_fill_value, check_slice_op from ..core.datatypes import mstype_to_detype @@ -64,3 +65,46 @@ class TypeCast(cde.TypeCastOp): data_type = mstype_to_detype(data_type) self.data_type = str(data_type) super().__init__(data_type) + + +class Slice(cde.SliceOp): + """ + Slice operation to extract a tensor out using the given n slices. + + The functionality of Slice is similar to NumPy indexing feature. + + (Currently only rank 1 Tensors are supported) + + Args: + *slices: Maximum n number of objects to slice a tensor of rank n. + One object in slices can be one of: + 1. int: slice this index only. Negative index is supported. + 2. slice object: slice the generated indices from the slice object. Similar to `start:stop:step`. + 3. None: slice the whole dimension. Similar to `:` in python indexing. + 4. Ellipses ...: slice all dimensions between the two slices. + Examples: + >>> # Data before + >>> # | col | + >>> # +---------+ + >>> # | [1,2,3] | + >>> # +---------| + >>> data = data.map(operations=Slice(slice(1,3))) # slice indices 1 and 2 only + >>> # Data after + >>> # | col | + >>> # +------------+ + >>> # | [1,2] | + >>> # +------------| + """ + + @check_slice_op + def __init__(self, *slices): + dim0 = slices[0] + if isinstance(dim0, int): + dim0 = [dim0] + elif dim0 is None: + dim0 = True + elif isinstance(dim0, slice): + dim0 = (dim0.start, dim0.stop, dim0.step) + elif dim0 is Ellipsis: + dim0 = True + super().__init__(dim0) diff --git a/mindspore/dataset/transforms/validators.py b/mindspore/dataset/transforms/validators.py index a7eb589cd7..0ca0d60412 100644 --- a/mindspore/dataset/transforms/validators.py +++ b/mindspore/dataset/transforms/validators.py @@ -15,6 +15,7 @@ """Validators for TensorOps. """ from functools import wraps + from mindspore._c_expression import typing # POS_INT_MIN is used to limit values from starting from 0 @@ -195,3 +196,20 @@ def check_de_type(method): return method(self, **kwargs) return new_method + + +def check_slice_op(method): + """Wrapper method to check the parameters of slice.""" + + @wraps(method) + def new_method(self, *args): + for i, arg in enumerate(args): + if arg is not None and arg is not Ellipsis and not isinstance(arg, (int, slice, list)): + raise TypeError("Indexing of dim " + str(i) + "is not of valid type") + if isinstance(arg, list): + for a in arg: + if not isinstance(a, int): + raise TypeError("Index " + a + " is not an int") + return method(self, *args) + + return new_method diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc index 6c7402c6bb..d47d22fb9c 100644 --- a/tests/ut/cpp/dataset/tensor_test.cc +++ b/tests/ut/cpp/dataset/tensor_test.cc @@ -28,17 +28,13 @@ using namespace mindspore::dataset; namespace py = pybind11; - class MindDataTestTensorDE : public UT::Common { public: - MindDataTestTensorDE() {} + MindDataTestTensorDE() {} - void SetUp() { - GlobalInit(); - } + void SetUp() { GlobalInit(); } }; - TEST_F(MindDataTestTensorDE, Basics) { std::shared_ptr t = std::make_shared(TensorShape({2, 3}), DataType(DataType::DE_UINT64)); ASSERT_TRUE((t->AllocateBuffer(t->SizeInBytes())).IsOk()); @@ -167,8 +163,7 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { // Test the bug of Tensor::ToString will exec failed for Tensor which store bool values TEST_F(MindDataTestTensorDE, BoolTensor) { - std::shared_ptr t = std::make_shared(TensorShape({2}), - DataType(DataType::DE_BOOL)); + std::shared_ptr t = std::make_shared(TensorShape({2}), DataType(DataType::DE_BOOL)); t->SetItemAt({0}, true); t->SetItemAt({1}, true); std::string out = t->ToString(); @@ -255,14 +250,19 @@ void checkCvMat(TensorShape shape, DataType type) { } else { ASSERT_EQ(m.size[0], shape[0]); } - if (shape.Rank() == 3) { ASSERT_EQ(m.channels(), shape[2]); } + if (shape.Rank() == 3) { + ASSERT_EQ(m.channels(), shape[2]); + } ASSERT_EQ(m.dims, 2); ASSERT_EQ(m.size.dims(), 2); - if (shape.Rank() > 0) { ASSERT_EQ(m.rows, shape[0]); } - if (shape.Rank() > 1) { ASSERT_EQ(m.cols, shape[1]); } + if (shape.Rank() > 0) { + ASSERT_EQ(m.rows, shape[0]); + } + if (shape.Rank() > 1) { + ASSERT_EQ(m.cols, shape[1]); + } } else { - for (dsize_t i = 0; i < shape.Rank(); i++) - ASSERT_EQ(m.size[static_cast(i)], shape[i]); + for (dsize_t i = 0; i < shape.Rank(); i++) ASSERT_EQ(m.size[static_cast(i)], shape[i]); ASSERT_EQ(m.dims, shape.Rank()); ASSERT_EQ(m.size.dims(), shape.Rank()); ASSERT_EQ(m.rows, -1); @@ -394,3 +394,16 @@ TEST_F(MindDataTestTensorDE, TensorIterator) { } ASSERT_TRUE(ctr == 6); } + +TEST_F(MindDataTestTensorDE, TensorSlice) { + std::shared_ptr t; + Tensor::CreateTensor(&t, std::vector{0, 1, 2, 3, 4}); + std::shared_ptr t2; + auto x = std::vector{0, 3, 4}; + std::shared_ptr expected; + Tensor::CreateTensor(&expected, x); + t->Slice(&t2, x); + ASSERT_EQ(*t2, *expected); + t->Slice(&t2, std::vector{0, 1, 2, 3, 4}); + ASSERT_EQ(*t2, *t); +} diff --git a/tests/ut/python/dataset/test_slice_op.py b/tests/ut/python/dataset/test_slice_op.py new file mode 100644 index 0000000000..10038b9e2b --- /dev/null +++ b/tests/ut/python/dataset/test_slice_op.py @@ -0,0 +1,211 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing TypeCast op in DE +""" +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.transforms.c_transforms as ops + + +def slice_compare(array, indexing): + data = ds.NumpySlicesDataset([array]) + array = np.array(array) + data = data.map(operations=ops.Slice(indexing)) + for d in data: + if indexing is None: + array = array[:] + else: + array = array[indexing] + np.testing.assert_array_equal(array, d[0]) + + +def test_slice_all(): + slice_compare([1, 2, 3, 4, 5], None) + slice_compare([1, 2, 3, 4, 5], ...) + + +def test_slice_single_index(): + slice_compare([1, 2, 3, 4, 5], 0) + slice_compare([1, 2, 3, 4, 5], 4) + slice_compare([1, 2, 3, 4, 5], 2) + slice_compare([1, 2, 3, 4, 5], -1) + slice_compare([1, 2, 3, 4, 5], -5) + slice_compare([1, 2, 3, 4, 5], -3) + + +def test_slice_list_index(): + slice_compare([1, 2, 3, 4, 5], [0, 1, 4]) + slice_compare([1, 2, 3, 4, 5], [4, 1, 0]) + slice_compare([1, 2, 3, 4, 5], [-1, 1, 0]) + slice_compare([1, 2, 3, 4, 5], [-1, -4, -2]) + slice_compare([1, 2, 3, 4, 5], [3, 3, 3]) + slice_compare([1, 2, 3, 4, 5], [1, 1, 1, 1, 1]) + + +def test_slice_slice_obj_2s(): + slice_compare([1, 2, 3, 4, 5], slice(0, 2)) + slice_compare([1, 2, 3, 4, 5], slice(2, 4)) + slice_compare([1, 2, 3, 4, 5], slice(4, 10)) + + +def test_slice_slice_obj_1s(): + slice_compare([1, 2, 3, 4, 5], slice(1)) + slice_compare([1, 2, 3, 4, 5], slice(4)) + slice_compare([1, 2, 3, 4, 5], slice(10)) + + +def test_slice_slice_obj_3s(): + slice_compare([1, 2, 3, 4, 5], slice(0, 2, 1)) + slice_compare([1, 2, 3, 4, 5], slice(0, 4, 1)) + slice_compare([1, 2, 3, 4, 5], slice(0, 10, 1)) + slice_compare([1, 2, 3, 4, 5], slice(0, 5, 2)) + slice_compare([1, 2, 3, 4, 5], slice(0, 2, 2)) + slice_compare([1, 2, 3, 4, 5], slice(0, 1, 2)) + slice_compare([1, 2, 3, 4, 5], slice(4, 5, 1)) + slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3)) + + +def test_slice_slice_obj_3s_double(): + slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1)) + slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1)) + slice_compare([1., 2., 3., 4., 5.], slice(0, 10, 1)) + slice_compare([1., 2., 3., 4., 5.], slice(0, 5, 2)) + slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 2)) + slice_compare([1., 2., 3., 4., 5.], slice(0, 1, 2)) + slice_compare([1., 2., 3., 4., 5.], slice(4, 5, 1)) + slice_compare([1., 2., 3., 4., 5.], slice(2, 5, 3)) + + +def test_slice_slice_obj_neg(): + slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -1)) + slice_compare([1, 2, 3, 4, 5], slice(-1)) + slice_compare([1, 2, 3, 4, 5], slice(-2)) + slice_compare([1, 2, 3, 4, 5], slice(-1, -5, -2)) + slice_compare([1, 2, 3, 4, 5], slice(-5, -1, 2)) + slice_compare([1, 2, 3, 4, 5], slice(-5, -1)) + + +def test_slice_exceptions(): + with pytest.raises(RuntimeError) as info: + slice_compare([1, 2, 3, 4, 5], 5) + assert "Index 5 is out of bounds [0,5)" in str(info.value) + + with pytest.raises(RuntimeError) as info: + slice_compare([1, 2, 3, 4, 5], slice(0)) + assert "Indices are empty, generated tensor would be empty." in str(info.value) + + with pytest.raises(RuntimeError) as info: + slice_compare([1, 2, 3, 4, 5], slice(5, 10, 1)) + assert "Indices are empty, generated tensor would be empty." in str(info.value) + + with pytest.raises(RuntimeError) as info: + slice_compare([1, 2, 3, 4, 5], slice(-1, -5, 1)) + assert "Indices are empty, generated tensor would be empty." in str(info.value) + + +def test_slice_all_str(): + slice_compare([b"1", b"2", b"3", b"4", b"5"], None) + slice_compare([b"1", b"2", b"3", b"4", b"5"], ...) + + +def test_slice_single_index_str(): + slice_compare([b"1", b"2", b"3", b"4", b"5"], 0) + slice_compare([b"1", b"2", b"3", b"4", b"5"], 4) + slice_compare([b"1", b"2", b"3", b"4", b"5"], 2) + slice_compare([b"1", b"2", b"3", b"4", b"5"], -1) + slice_compare([b"1", b"2", b"3", b"4", b"5"], -5) + slice_compare([b"1", b"2", b"3", b"4", b"5"], -3) + + +def test_slice_list_index_str(): + slice_compare([b"1", b"2", b"3", b"4", b"5"], [0, 1, 4]) + slice_compare([b"1", b"2", b"3", b"4", b"5"], [4, 1, 0]) + slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1, 1, 0]) + slice_compare([b"1", b"2", b"3", b"4", b"5"], [-1, -4, -2]) + slice_compare([b"1", b"2", b"3", b"4", b"5"], [3, 3, 3]) + slice_compare([b"1", b"2", b"3", b"4", b"5"], [1, 1, 1, 1, 1]) + + +def test_slice_slice_obj_2s_str(): + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 4)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4, 10)) + + +def test_slice_slice_obj_1s_str(): + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(10)) + + +def test_slice_slice_obj_3s_str(): + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 4, 1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 10, 1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 5, 2)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 2, 2)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0, 1, 2)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(4, 5, 1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(2, 5, 3)) + + +def test_slice_slice_obj_neg_str(): + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-2)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, -2)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1, 2)) + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-5, -1)) + + +def test_slice_exceptions_str(): + with pytest.raises(RuntimeError) as info: + slice_compare([b"1", b"2", b"3", b"4", b"5"], 5) + assert "Index 5 is out of bounds [0,5)" in str(info.value) + + with pytest.raises(RuntimeError) as info: + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(0)) + assert "Indices are empty, generated tensor would be empty." in str(info.value) + + with pytest.raises(RuntimeError) as info: + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(5, 10, 1)) + assert "Indices are empty, generated tensor would be empty." in str(info.value) + + with pytest.raises(RuntimeError) as info: + slice_compare([b"1", b"2", b"3", b"4", b"5"], slice(-1, -5, 1)) + assert "Indices are empty, generated tensor would be empty." in str(info.value) + + +if __name__ == "__main__": + test_slice_all() + test_slice_single_index() + test_slice_list_index() + test_slice_slice_obj_3s() + test_slice_slice_obj_2s() + test_slice_slice_obj_1s() + test_slice_slice_obj_neg() + test_slice_exceptions() + test_slice_slice_obj_3s_double() + test_slice_all_str() + test_slice_single_index_str() + test_slice_list_index_str() + test_slice_slice_obj_3s_str() + test_slice_slice_obj_2s_str() + test_slice_slice_obj_1s_str() + test_slice_slice_obj_neg_str() + test_slice_exceptions_str()