diff --git a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc b/mindspore/ccsrc/dataset/kernels/data/slice_op.cc index 8401a33511..2eebf26e84 100644 --- a/mindspore/ccsrc/dataset/kernels/data/slice_op.cc +++ b/mindspore/ccsrc/dataset/kernels/data/slice_op.cc @@ -33,8 +33,8 @@ Status SliceOp::Compute(const std::shared_ptr &input, std::shared_ptrshape()[0]; - indices_ = slice_.Indices(len); - return input->Slice(output, indices_); + std::vector indices = slice_.Indices(len); + return input->Slice(output, indices); } // if indices are not empty, slices should be invalid, use indices_ to slice diff --git a/tests/ut/python/dataset/test_slice_op.py b/tests/ut/python/dataset/test_slice_op.py index fd5e8baac9..6e81133a2a 100644 --- a/tests/ut/python/dataset/test_slice_op.py +++ b/tests/ut/python/dataset/test_slice_op.py @@ -80,6 +80,22 @@ def test_slice_slice_obj_3s(): slice_compare([1, 2, 3, 4, 5], slice(2, 5, 3)) +def test_slice_multiple_rows(): + dataset = [[1, 2], [3, 4, 5], [1], [1, 2, 3, 4, 5, 6, 7]] + + def gen(): + for row in dataset: + yield (np.array(row),) + + data = ds.GeneratorDataset(gen, column_names=["col"]) + indexing = slice(0, 4) + data = data.map(operations=ops.Slice(indexing)) + for i, d in enumerate(data): + array = np.array(dataset[i]) + array = array[indexing] + np.testing.assert_array_equal(array, d[0]) + + def test_slice_slice_obj_3s_double(): slice_compare([1., 2., 3., 4., 5.], slice(0, 2, 1)) slice_compare([1., 2., 3., 4., 5.], slice(0, 4, 1)) @@ -217,3 +233,4 @@ if __name__ == "__main__": test_slice_slice_obj_1s_str() test_slice_slice_obj_neg_str() test_slice_exceptions_str() + test_slice_multiple_rows()