@ -17,11 +17,13 @@ Testing cache operator with non-mappable datasets
import os
import itertools
import numpy as np
import pytest
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.text as text
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import log as logger
DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"]
@ -41,6 +43,9 @@ CLUE_DATA_DIR = '../data/dataset/testCLUE/afqmc/train.json'
CSV_DATA_DIR = '../data/dataset/testCSV/1.csv'
TEXT_FILE_DATA_DIR = "../data/dataset/testTextFileDataset/1.txt"
PYFUNC_DATA_DIR = ["../data/dataset/testPyfuncMap/data.data"]
PYFUNC_SCHEMA_DIR = "../data/dataset/testPyfuncMap/schema.json"
@ -1633,7 +1638,7 @@ def test_cache_nomap_clue2():
some_cache = ds.DatasetCache(session_id=session_id, size=0)
ds1 = ds.CLUEDataset(CLUE_DATA_DIR, task='AFQMC', usage='train', num_samples=2)
ds1 = ds1.map((lambda x: x), ["label"], cache=some_cache)
ds1 = ds1.map(py_vision.not_random(lambda x: x), ["label"], cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
@ -1710,7 +1715,7 @@ def test_cache_nomap_csv2():
ds1 = ds.CSVDataset(CSV_DATA_DIR, column_defaults=["1", "2", "3", "4"],
column_names=['col1', 'col2', 'col3', 'col4'], num_samples=2)
ds1 = ds1.map((lambda x: x), ["col1"], cache=some_cache)
ds1 = ds1.map(py_vision.not_random(lambda x: x), ["col1"], cache=some_cache)
num_epoch = 4
iter1 = ds1.create_dict_iterator(num_epochs=num_epoch, output_numpy=True)
@ -2124,6 +2129,139 @@ def test_cache_nomap_failure5():
logger.info('test_cache_nomap_failure5 Ended.\n')
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_lambda():
    """
    Test cache after map op with a python lambda function.
    Only allowed if the lambda function is wrapped by 'py_vision.not_random', otherwise an error will be raised.

        Cache
          |
        Map(lambda function1, lambda function2)
          |
        TFRecord
    """
    logger.info("Test cache nomap pyfunc lambda")
    # Guard clause: the session id of the running cache server must come from the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 12 records in it
    data1 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    # Both lambdas are wrapped by not_random, so caching the map output is expected to be accepted.
    transforms = [py_vision.not_random(lambda x: x + x), py_vision.not_random(lambda x: x - 1)]
    data1 = data1.map(operations=transforms, input_columns="col0", cache=some_cache)

    num_iter = 0
    for _ in data1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 12

    other_cache = ds.DatasetCache(session_id=session_id, size=0)
    ds2 = ds.TFRecordDataset(PYFUNC_DATA_DIR, PYFUNC_SCHEMA_DIR, shuffle=False)
    # A bare lambda (not wrapped by not_random) under a cache is expected to be rejected.
    ds2 = ds2.map(operations=[(lambda x: x + x)], input_columns=["col0"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        # Only the raised error matters here, so the rows are not counted.
        for _ in ds2.create_dict_iterator():
            pass
    # Checked after the 'with' block: statements following the raising line inside it would never run.
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_lambda Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_builtin():
    """
    Test cache after map op with a python builtin PyFunc.
    An error will be raised if the builtin pyfunc contains a random operation.

        Cache
          |
        Map([builtin pyfunc1, builtin pyfunc2])
          |
        TFRecord
    """
    logger.info("Test cache nomap pyfunc builtin")
    # Guard clause: the session id of the running cache server must come from the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    # Decode and ToTensor are the non-random case: this cached map is expected to run to completion.
    ds1 = ds1.map(operations=[py_vision.Decode(), py_vision.ToTensor()], input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    other_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    # RandomCrop makes this map a random operation, which is expected to be rejected under a cache.
    ds2 = ds2.map(operations=[py_vision.Decode(), py_vision.RandomCrop(224), py_vision.ToTensor()],
                  input_columns=["image"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        # Only the raised error matters here, so the rows are not counted.
        for _ in ds2.create_dict_iterator():
            pass
    # Checked after the 'with' block: statements following the raising line inside it would never run.
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_builtin Ended.\n")
@pytest.mark.skipif(os.environ.get('RUN_CACHE_TEST') != 'TRUE', reason="Require to bring up cache server")
def test_cache_nomap_pyfunc_function():
    """
    Test cache after map op with a python customized function.
    Only allowed if the function is decorated with 'py_vision.not_random', otherwise an error will be raised.

        Cache
          |
        Map([function1, function2])
          |
        TFRecord
    """

    # The decorator is what distinguishes this function from normal_func below; the bodies are identical.
    @py_vision.not_random
    def not_random_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    # Deliberately left undecorated so that the second pipeline is expected to be rejected.
    def normal_func(x):
        return np.ones(x.shape, dtype=x.dtype)

    logger.info("Test cache nomap pyfunc function")
    # Guard clause: the session id of the running cache server must come from the environment.
    if "SESSION_ID" not in os.environ:
        raise RuntimeError("Testcase requires SESSION_ID environment variable")
    session_id = int(os.environ['SESSION_ID'])

    some_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    ds1 = ds1.map(operations=[not_random_func, not_random_func], input_columns=["image"], cache=some_cache)

    num_iter = 0
    for _ in ds1.create_dict_iterator(num_epochs=1):
        num_iter += 1
    assert num_iter == 3

    other_cache = ds.DatasetCache(session_id=session_id, size=0)

    # This dataset has 3 records in it only
    ds2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"])
    # Mixing a decorated and an undecorated function: the undecorated one should trigger the error.
    ds2 = ds2.map(operations=[not_random_func, normal_func], input_columns=["image"], cache=other_cache)

    with pytest.raises(RuntimeError) as e:
        # Only the raised error matters here, so the rows are not counted.
        for _ in ds2.create_dict_iterator():
            pass
    # Checked after the 'with' block: statements following the raising line inside it would never run.
    assert "MapNode containing random operation is not supported as a descendant of cache" in str(e.value)

    logger.info("test_cache_nomap_pyfunc_function Ended.\n")
if __name__ == '__main__':
# This is just a list of tests, don't try to run these tests with 'python test_cache_nomap.py'
# since cache server is required to be brought up first
@ -2180,3 +2318,6 @@ if __name__ == '__main__':