Make shared decorated readers' creater be only in main_program

fea/docker_cudnn7
fengjiayi 7 years ago
parent 3f90a583b4
commit 5416bac5d8

@ -109,7 +109,9 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase {
auto place_str = Attr<std::string>("place");
platform::Place place;
if (place_str == "CPU") {
if (place_str == "AUTO") {
place = dev_place;
} else if (place_str == "CPU") {
place = platform::CPUPlace();
} else {
std::istringstream sin(place_str);
@ -140,8 +142,9 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
enum_range.insert(string::Sprintf("CUDA:%d", i));
}
enum_range.insert("CPU");
AddAttr<std::string>("place", "The double buffer place, default is CPU")
.SetDefault("CPU")
enum_range.insert("AUTO");
AddAttr<std::string>("place", "The double buffer place")
.SetDefault("AUTO")
.InEnum({enum_range});
}
};

@ -440,7 +440,7 @@ def open_files(filenames,
return monkey_patch_reader_methods(main_prog_reader)
def __create_decorated_reader__(op_type, reader, attrs={}):
def __create_unshared_decorated_reader__(op_type, reader, attrs={}):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
@ -456,26 +456,40 @@ def __create_decorated_reader__(op_type, reader, attrs={}):
return monkey_patch_reader_methods(main_prog_var)
def __create_shared_decorated_reader__(op_type, reader, attrs={}):
new_reader_name = unique_name(op_type)
main_blk = default_main_program().current_block()
new_reader = main_blk.create_var(name=new_reader_name)
main_blk.append_op(
type=op_type,
inputs={'UnderlyingReader': reader},
outputs={'Out': [new_reader]},
attrs=attrs)
new_reader.persistable = True
new_reader.stop_gradient = True
return monkey_patch_reader_methods(new_reader)
def shuffle(reader, buffer_size):
return __create_decorated_reader__('create_shuffle_reader', reader,
{'buffer_size': int(buffer_size)})
return __create_unshared_decorated_reader__(
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
def double_buffer(reader, place=None):
attrs = dict()
if place is not None:
attrs['place'] = str(place).upper()
return __create_decorated_reader__('create_double_buffer_reader', reader,
attrs)
return __create_unshared_decorated_reader__('create_double_buffer_reader',
reader, attrs)
def multi_pass(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader,
{'pass_num': int(pass_num)})
return __create_shared_decorated_reader__(
'create_multi_pass_reader', reader, {'pass_num': int(pass_num)})
def for_parallel(reader):
return __create_decorated_reader__('create_threaded_reader', reader)
return __create_shared_decorated_reader__('create_threaded_reader', reader)
def read_file(file_obj):

Loading…
Cancel
Save