hidden the dataset call of pipeline to train_from_dataset (#25834)

* hidden the explicit setting of dataset for pipeline training.
fix_copy_if_different
lilong12 5 years ago committed by GitHub
parent f132c2f40e
commit a07b62623e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1334,14 +1334,25 @@ class Executor(object):
fetch_info=None,
print_period=100,
fetch_handler=None):
if dataset is None:
raise RuntimeError("dataset is need and should be initialized")
if program._pipeline_opt is not None and program._pipeline_opt[
"sync_steps"] != -1:
# hack for paddlebox: sync_steps(-1) denotes paddlebox
thread = self._adjust_pipeline_resource(program._pipeline_opt,
dataset, thread)
if program._pipeline_opt is not None:
import paddle
if dataset is not None:
raise RuntimeError("dataset should be None for pipeline mode")
# The following fake dataset is created to call
# the _prepare_trainer api, and it is meaningless.
data_vars = []
for var in program.global_block().vars.values():
if var.is_data:
data_vars.append(var)
dataset = paddle.fluid.DatasetFactory().create_dataset(
'FileInstantDataset')
dataset.set_batch_size(1)
dataset.set_thread(1)
dataset.set_filelist(['None'])
dataset.set_use_var(data_vars)
else:
if dataset is None:
raise RuntimeError("dataset is need and should be initialized")
dataset._prepare_to_run()

@ -186,18 +186,10 @@ class TestPipeline(unittest.TestCase):
data_loader.set_sample_generator(train_reader, batch_size=1)
place = fluid.CPUPlace()
# The following dataset is only used for the
# interface 'train_from_dataset'.
# And it has no actual meaning.
dataset = fluid.DatasetFactory().create_dataset('FileInstantDataset')
dataset.set_batch_size(1)
dataset.set_thread(1)
dataset.set_filelist(['/tmp/tmp_2.txt'])
dataset.set_use_var([image, label])
exe = fluid.Executor(place)
exe.run(startup_prog)
data_loader.start()
exe.train_from_dataset(main_prog, dataset, debug=debug)
exe.train_from_dataset(main_prog, debug=debug)
def test_pipeline(self):
self._run(False)

Loading…
Cancel
Save