@ -27,6 +27,7 @@ from .dataloader.dataloader_iter import _DataLoaderIterSingleProcess, _DataLoade
from . dataloader . batch_sampler import _InfiniteIterableSampler
from . layers . io import monkey_patch_reader_methods , _copy_reader_var_ , double_buffer
from . unique_name import UniqueNameGenerator
from . framework import _get_paddle_place , _get_paddle_place_list
import logging
import warnings
@ -186,10 +187,12 @@ class DataLoader(object):
The Tensors should be created by : code : ` paddle . static . data ( ) ` .
: attr : ` feed_list ` must be set if : attr : ` return_list ` is
False . Default None .
places ( list ( Place ) | tuple ( Place ) | optional ) : a list of Place ,
places ( list ( Place ) | tuple ( Place ) | list ( str ) | optional ) : a list of Place ,
to put data onto , : attr : ` places ` can be None , if
: attr : ` places ` is None , default place ( CPUPlace or CUDAPlace ( 0 ) )
will be used . Default None .
will be used . Default None . If ` ` places ` ` is list of string ,
the string in the list can be ` ` cpu ` ` , ` ` gpu : x ` ` and ` ` gpu_pinned ` ` ,
where ` ` x ` ` is the index of the GPUs .
return_list ( bool ) : whether the return value on each device is
presented as a list . If : attr : ` return_list = False ` , the return
value on each device would be a dict of str - > Tensor , where
@ -335,6 +338,10 @@ class DataLoader(object):
if places is None :
places = _current_expected_place ( )
if isinstance ( places , ( list , tuple ) ) :
places = _get_paddle_place_list ( places )
else :
places = _get_paddle_place ( places )
self . places = _convert_places ( places )
assert num_workers > = 0 , " num_workers should be a non-negative value "
@ -752,8 +759,9 @@ class DataLoader(object):
Args :
dataset ( InMemoryDataset | QueueDataset ) : the dataset object .
places ( list ( CUDAPlace ) | list ( CPUPlace ) ) : places where the result
data should be converted .
places ( list ( CUDAPlace ) | list ( CPUPlace ) | list ( str ) ) : places where the result
data should be converted . If places is list of string , the string in the list
can be ` ` cpu ` ` , ` ` gpu : x ` ` and ` ` gpu_pinned ` ` , where x is the index of the GPUs .
drop_last ( bool ) : whether to drop the last batch whose sample
number is less than batch size . If drop_last = True , they
would be dropped . If drop_last = False , they would be kept .
@ -1030,6 +1038,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
drop_last = True ,
places = None ) :
assert batch_size > 0 , " batch_size must be larger than 0 "
if isinstance ( places , ( list , tuple ) ) :
places = _get_paddle_place_list ( places )
else :
places = _get_paddle_place ( places )
self . set_sample_list_generator (
paddle . batch (
reader , batch_size = batch_size , drop_last = drop_last ) ,
@ -1037,6 +1049,11 @@ class DygraphGeneratorLoader(DataLoaderBase):
return self
def set_sample_list_generator ( self , reader , places = None ) :
if isinstance ( places , ( list , tuple ) ) :
places = _get_paddle_place_list ( places )
else :
places = _get_paddle_place ( places )
def __batch_reader_impl__ ( ) :
for batch in reader ( ) :
slots = [ ]
@ -1052,6 +1069,10 @@ class DygraphGeneratorLoader(DataLoaderBase):
return self
def set_batch_generator ( self , reader , places = None ) :
if isinstance ( places , ( list , tuple ) ) :
places = _get_paddle_place_list ( places )
else :
places = _get_paddle_place ( places )
self . _batch_reader = reader
if places is None :
places = _current_expected_place ( )
@ -1275,6 +1296,10 @@ class GeneratorLoader(DataLoaderBase):
drop_last = True ,
places = None ) :
assert batch_size > 0 , " batch_size must be larger than 0 "
if isinstance ( places , ( list , tuple ) ) :
places = _get_paddle_place_list ( places )
else :
places = _get_paddle_place ( places )
has_lod = False
for f in self . _feed_list :
if f . lod_level != 0 :
@ -1297,6 +1322,10 @@ class GeneratorLoader(DataLoaderBase):
return self
def set_sample_list_generator ( self , reader , places = None ) :
if isinstance ( places , ( list , tuple ) ) :
places = _get_paddle_place_list ( places )
else :
places = _get_paddle_place ( places )
with program_guard ( Program ( ) , Program ( ) ) :
feeder = DataFeeder (
feed_list = self . _feed_list , place = core . CPUPlace ( ) )
@ -1310,6 +1339,10 @@ class GeneratorLoader(DataLoaderBase):
return self
def set_batch_generator ( self , reader , places = None ) :
if isinstance ( places , ( list , tuple ) ) :
places = _get_paddle_place_list ( places )
else :
places = _get_paddle_place ( places )
self . _tensor_reader = reader
if self . _iterable :
assert places is not None , " Places cannot be None when DataLoader is iterable "
@ -1784,6 +1817,10 @@ class DatasetLoader(DataLoaderBase):
DatasetBase ) , " dataset must be type of DatasetBase "
assert not in_dygraph_mode (
) , " DatasetLoader is not supported in dygraph mode yet "
if isinstance ( places , ( list , tuple ) ) :
places = _get_paddle_place_list ( places )
else :
places = _get_paddle_place ( places )
thread_num = len ( places )