|
|
|
@ -390,7 +390,6 @@ def filter_func_Partial_0(col1, col2, col3, col4):
|
|
|
|
|
|
|
|
|
|
# test with row_data_buffer > 1
|
|
|
|
|
def test_filter_by_generator_Partial0():
|
|
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg')
|
|
|
|
|
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
|
|
|
|
|
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
|
|
|
|
|
dataset_zip = ds.zip((dataset1, dataset2))
|
|
|
|
@ -404,7 +403,6 @@ def test_filter_by_generator_Partial0():
|
|
|
|
|
|
|
|
|
|
# test with row_data_buffer > 1
|
|
|
|
|
def test_filter_by_generator_Partial1():
|
|
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg')
|
|
|
|
|
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
|
|
|
|
|
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
|
|
|
|
|
dataset_zip = ds.zip((dataset1, dataset2))
|
|
|
|
@ -419,7 +417,6 @@ def test_filter_by_generator_Partial1():
|
|
|
|
|
|
|
|
|
|
# test with row_data_buffer > 1
|
|
|
|
|
def test_filter_by_generator_Partial2():
|
|
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg')
|
|
|
|
|
dataset1 = ds.GeneratorDataset(source=generator_mc_p0(), column_names=["col1", "col2"])
|
|
|
|
|
dataset2 = ds.GeneratorDataset(source=generator_mc_p1(), column_names=["col3", "col4"])
|
|
|
|
|
|
|
|
|
@ -454,7 +451,6 @@ def generator_big(maxid=20):
|
|
|
|
|
|
|
|
|
|
# test with row_data_buffer > 1
|
|
|
|
|
def test_filter_by_generator_Partial():
|
|
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg')
|
|
|
|
|
dataset = ds.GeneratorDataset(source=generator_mc(99), column_names=["col1", "col2"])
|
|
|
|
|
dataset_s = dataset.shuffle(4)
|
|
|
|
|
dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1)
|
|
|
|
@ -473,7 +469,6 @@ def filter_func_cifar(col1, col2):
|
|
|
|
|
# test with cifar10
|
|
|
|
|
def test_filte_case_dataset_cifar10():
|
|
|
|
|
DATA_DIR_10 = "../data/dataset/testCifar10Data"
|
|
|
|
|
ds.config.load('../data/dataset/declient_filter.cfg')
|
|
|
|
|
dataset_c = ds.Cifar10Dataset(dataset_dir=DATA_DIR_10, num_samples=100000, shuffle=False)
|
|
|
|
|
dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1)
|
|
|
|
|
for item in dataset_f1.create_dict_iterator():
|
|
|
|
|