Merge branch 'feature/rewrite_open_files' into feature/combine_open_files_and_double_buffer

guochaorong-patch-1
yuyang18 7 years ago
commit c5023bea75
No known key found for this signature in database
GPG Key ID: 6DFF29878217BE5F

File diff suppressed because it is too large Load Diff

@ -28,6 +28,7 @@ Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
Scanner::Scanner(const std::string &filename)
: stream_(new std::ifstream(filename)), parser_(*stream_) {
PADDLE_ENFORCE(static_cast<bool>(*stream_), "Cannot open file %s", filename);
Reset();
}

@ -20,6 +20,7 @@ from control_flow import BlockGuard
from ..layer_helper import LayerHelper
from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc
import sys
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
@ -532,10 +533,10 @@ def open_files(filenames,
shapes,
lod_levels,
dtypes,
thread_num=1,
thread_num=None,
buffer_size=None,
pass_num=1,
for_parallel=True):
is_test=None):
"""
Open files
@ -548,14 +549,15 @@ def open_files(filenames,
shapes(list): List of tuples which declaring data shapes.
lod_levels(list): List of ints which declaring data lod_level.
dtypes(list): List of strs which declaring data type.
thread_num(int): The maximal concurrent prefetch thread number.
buffer_size(int|None): The size of prefetch buffer. If it is setted None,
buffer size will be thread_num * 3.
Default: None
thread_num(None): Deprecated argument. It will be set by open_files
automatically.
buffer_size(None): Deprecated argument. It will be set by open_files
automatically.
pass_num(int): Number of passes to run.
for_parallel(Bool): Set it as True if you are going to run
subsequent operators in parallel.
Default: True
is_test(bool|None): Whether `open_files` used for testing or not. If it
is used for testing, the order of data generated is same as the file
order. Otherwise, it is not guaranteed the order of data is same
between every epoch. [Default: False].
Returns:
Variable: A Reader Variable via which we can get file data.
@ -567,15 +569,20 @@ def open_files(filenames,
'./data2.recordio'],
shapes=[(3,224,224), (1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=2,
buffer_size=2)
dtypes=['float32', 'int64'])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
"""
if buffer_size is None:
buffer_size = thread_num * 3
if thread_num is not None:
print >> sys.stderr, "thread_num parameter of open_files is " \
"deprecated. It will be ignored and set " \
"automatically by open_files "
if buffer_size is not None:
print >> sys.stderr, "buffer_size parameter of open_files is " \
"deprecated. It will be ignored and set " \
"automatically by open_files "
if isinstance(filenames, basestring):
filenames = [filenames]
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
@ -589,17 +596,16 @@ def open_files(filenames,
multi_file_reader_name = unique_name('multi_file_reader')
startup_blk = default_startup_program().current_block()
startup_reader = startup_blk.create_var(name=multi_file_reader_name)
attrs = {
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames
}
if is_test is not None:
attrs['is_test'] = is_test
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_reader]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num,
'buffer_size': buffer_size
})
type='open_files', outputs={'Out': [startup_reader]}, attrs=attrs)
startup_reader.desc.set_dtypes(dtypes)
startup_reader.persistable = True

@ -31,7 +31,10 @@ def load_vocab(filename):
# load word dict with paddle inner function
word_dict = load_vocab(sys.argv[1])
if len(sys.argv) == 1:
word_dict = paddle.dataset.imdb.word_dict()
else:
word_dict = load_vocab(sys.argv[1])
word_dict["<unk>"] = len(word_dict)
print "Dict dim = ", len(word_dict)

@ -41,16 +41,14 @@ def network_cfg(is_train, pass_num=100):
pass_num=pass_num,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'],
thread_num=1)
dtypes=['int64', 'int64'])
test_file_obj = fluid.layers.open_files(
filenames=TEST_FILES,
pass_num=1,
shapes=[[-1, 1], [-1, 1]],
lod_levels=[1, 0],
dtypes=['int64', 'int64'],
thread_num=1)
dtypes=['int64', 'int64'])
if is_train:
file_obj = fluid.layers.shuffle(train_file_obj, buffer_size=1000)

@ -39,17 +39,17 @@ class TestMultipleReader(unittest.TestCase):
copyfile('./mnist_0.recordio', './mnist_1.recordio')
copyfile('./mnist_0.recordio', './mnist_2.recordio')
def main(self, thread_num):
def main(self, is_test=False):
file_list = [
'./mnist_0.recordio', './mnist_1.recordio', './mnist_2.recordio'
]
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_files = fluid.layers.open_files(
filenames=file_list,
thread_num=thread_num,
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
dtypes=['float32', 'int64'],
is_test=is_test)
img, label = fluid.layers.read_file(data_files)
if fluid.core.is_compiled_with_cuda():
@ -71,6 +71,9 @@ class TestMultipleReader(unittest.TestCase):
self.assertEqual(batch_count, self.num_batch * 3)
def test_main(self):
self.main(thread_num=3) # thread number equals to file number
self.main(thread_num=10) # thread number is larger than file number
self.main(thread_num=2) # thread number is less than file number
self.main(is_test=False)
self.main(is_test=True)
if __name__ == '__main__':
unittest.main()

@ -32,9 +32,7 @@ def simple_fc_net(use_feed):
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)
hidden = img
@ -60,9 +58,7 @@ def fc_with_batchnorm(use_feed):
filenames=[MNIST_RECORDIO_FILE],
shapes=[[-1, 784], [-1, 1]],
lod_levels=[0, 0],
dtypes=['float32', 'int64'],
thread_num=1,
for_parallel=True)
dtypes=['float32', 'int64'])
reader = fluid.layers.io.double_buffer(reader)
img, label = fluid.layers.read_file(reader)

Loading…
Cancel
Save