|
|
@ -97,9 +97,11 @@ class DatasetLoaderTestBase(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
|
|
def check_batch_number(self, place, randomize_batch_num=False):
|
|
|
|
def check_batch_number(self, place, randomize_batch_num=False):
|
|
|
|
main_prog, startup_prog, feeds = self.build_network()
|
|
|
|
main_prog, startup_prog, feeds = self.build_network()
|
|
|
|
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset(
|
|
|
|
if self.dataset_name == "QueueDataset":
|
|
|
|
self.dataset_name)
|
|
|
|
dataset = paddle.distributed.QueueDataset()
|
|
|
|
dataset.set_batch_size(BATCH_SIZE)
|
|
|
|
else:
|
|
|
|
|
|
|
|
dataset = paddle.distributed.InMemoryDataset()
|
|
|
|
|
|
|
|
dataset._set_batch_size(BATCH_SIZE)
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(place, fluid.CPUPlace):
|
|
|
|
if isinstance(place, fluid.CPUPlace):
|
|
|
|
file_num = 10
|
|
|
|
file_num = 10
|
|
|
@ -128,8 +130,8 @@ class DatasetLoaderTestBase(unittest.TestCase):
|
|
|
|
fake_reader(batch_num=BATCH_NUM + random_delta_batch_size[i]))
|
|
|
|
fake_reader(batch_num=BATCH_NUM + random_delta_batch_size[i]))
|
|
|
|
|
|
|
|
|
|
|
|
dataset.set_filelist(filelist)
|
|
|
|
dataset.set_filelist(filelist)
|
|
|
|
dataset.set_use_var(feeds)
|
|
|
|
dataset._set_use_var(feeds)
|
|
|
|
dataset.set_pipe_command("cat")
|
|
|
|
dataset._set_pipe_command("cat")
|
|
|
|
if self.dataset_name == 'InMemoryDataset':
|
|
|
|
if self.dataset_name == 'InMemoryDataset':
|
|
|
|
dataset.load_into_memory()
|
|
|
|
dataset.load_into_memory()
|
|
|
|
|
|
|
|
|
|
|
|