|
|
|
@ -440,7 +440,7 @@ def open_files(filenames,
|
|
|
|
|
return monkey_patch_reader_methods(main_prog_reader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __create_decorated_reader__(op_type, reader, attrs={}):
|
|
|
|
|
def __create_unshared_decorated_reader__(op_type, reader, attrs={}):
|
|
|
|
|
var_name = unique_name(op_type)
|
|
|
|
|
startup_blk = default_startup_program().current_block()
|
|
|
|
|
startup_var = startup_blk.create_var(name=var_name)
|
|
|
|
@ -456,26 +456,40 @@ def __create_decorated_reader__(op_type, reader, attrs={}):
|
|
|
|
|
return monkey_patch_reader_methods(main_prog_var)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __create_shared_decorated_reader__(op_type, reader, attrs={}):
|
|
|
|
|
new_reader_name = unique_name(op_type)
|
|
|
|
|
main_blk = default_main_program().current_block()
|
|
|
|
|
new_reader = main_blk.create_var(name=new_reader_name)
|
|
|
|
|
main_blk.append_op(
|
|
|
|
|
type=op_type,
|
|
|
|
|
inputs={'UnderlyingReader': reader},
|
|
|
|
|
outputs={'Out': [new_reader]},
|
|
|
|
|
attrs=attrs)
|
|
|
|
|
new_reader.persistable = True
|
|
|
|
|
new_reader.stop_gradient = True
|
|
|
|
|
return monkey_patch_reader_methods(new_reader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def shuffle(reader, buffer_size):
|
|
|
|
|
return __create_decorated_reader__('create_shuffle_reader', reader,
|
|
|
|
|
{'buffer_size': int(buffer_size)})
|
|
|
|
|
return __create_unshared_decorated_reader__(
|
|
|
|
|
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def double_buffer(reader, place=None):
|
|
|
|
|
attrs = dict()
|
|
|
|
|
if place is not None:
|
|
|
|
|
attrs['place'] = str(place).upper()
|
|
|
|
|
return __create_decorated_reader__('create_double_buffer_reader', reader,
|
|
|
|
|
attrs)
|
|
|
|
|
return __create_unshared_decorated_reader__('create_double_buffer_reader',
|
|
|
|
|
reader, attrs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def multi_pass(reader, pass_num):
|
|
|
|
|
return __create_decorated_reader__('create_multi_pass_reader', reader,
|
|
|
|
|
{'pass_num': int(pass_num)})
|
|
|
|
|
return __create_shared_decorated_reader__(
|
|
|
|
|
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def for_parallel(reader):
|
|
|
|
|
return __create_decorated_reader__('create_threaded_reader', reader)
|
|
|
|
|
return __create_shared_decorated_reader__('create_threaded_reader', reader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_file(file_obj):
|
|
|
|
|