|
|
|
@ -16,6 +16,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision
|
|
|
|
|
from util import save_and_check
|
|
|
|
|
|
|
|
|
|
import mindspore.dataset as ds
|
|
|
|
|
import numpy as np
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
|
|
|
|
|
DATA_DIR_TF = ["../data/dataset/testTFTestAllTypes/test.data"]
|
|
|
|
@ -95,6 +96,141 @@ def test_tf_repeat_03():
|
|
|
|
|
assert num_iter == 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generator():
|
|
|
|
|
for i in range(3):
|
|
|
|
|
yield np.array([i]),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat1():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
data = data.repeat(3)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert i % 3 == d[0][0]
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 2 * 3 * 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat2():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat(1)
|
|
|
|
|
data = data.repeat(1)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert i % 3 == d[0][0]
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat3():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat(1)
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert i % 3 == d[0][0]
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 2 * 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat4():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
data = data.repeat(1)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert i % 3 == d[0][0]
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 2 * 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat5():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.batch(3)
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
data = data.repeat(3)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat6():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
data = data.batch(3)
|
|
|
|
|
data = data.repeat(3)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat7():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
data = data.repeat(3)
|
|
|
|
|
data = data.batch(3)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert np.array_equal(d[0], np.asarray([[0], [1], [2]]))
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat8():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.batch(2, drop_remainder=False)
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
data = data.repeat(3)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
if i % 2 == 0:
|
|
|
|
|
assert np.array_equal(d[0], np.asarray([[0], [1]]))
|
|
|
|
|
else:
|
|
|
|
|
assert np.array_equal(d[0], np.asarray([[2]]))
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 6 * 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat9():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat()
|
|
|
|
|
data = data.repeat(3)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert i % 3 == d[0][0]
|
|
|
|
|
if i == 10:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat10():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat(3)
|
|
|
|
|
data = data.repeat()
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert i % 3 == d[0][0]
|
|
|
|
|
if i == 10:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_nested_repeat11():
|
|
|
|
|
data = ds.GeneratorDataset(generator, ["data"])
|
|
|
|
|
data = data.repeat(2)
|
|
|
|
|
data = data.repeat(3)
|
|
|
|
|
data = data.repeat(4)
|
|
|
|
|
data = data.repeat(5)
|
|
|
|
|
|
|
|
|
|
for i, d in enumerate(data):
|
|
|
|
|
assert i % 3 == d[0][0]
|
|
|
|
|
|
|
|
|
|
assert sum([1 for _ in data]) == 2 * 3 * 4 * 5 * 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
logger.info("--------test tf repeat 01---------")
|
|
|
|
|
# test_repeat_01()
|
|
|
|
@ -104,4 +240,3 @@ if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
logger.info("--------test tf repeat 03---------")
|
|
|
|
|
test_tf_repeat_03()
|
|
|
|
|
|
|
|
|
|