@ -3239,21 +3239,19 @@ def _cpp_sampler_fn(sampler, dataset):
yield tuple ( [ np . array ( x , copy = False ) for x in val ] )
def _cpp_sampler_fn_mp ( sampler , dataset, num_worker , multi_process ) :
def _cpp_sampler_fn_mp ( sampler , sample_fn ) :
"""
Multiprocessing generator function wrapper for mappable dataset with cpp sampler .
"""
indices = sampler . get_indices ( )
sample_fn = SamplerFn ( dataset , num_worker , multi_process )
return sample_fn . process ( indices )
def _py_sampler_fn_mp(sampler, num_samples, sample_fn):
    """
    Multiprocessing generator function wrapper for mappable dataset with Python sampler.

    Args:
        sampler: Python sampler iterated by _fetch_py_sampler_indices.
        num_samples (int): upper bound on the number of indices to draw; may be None
            per _fetch_py_sampler_indices' contract (defined elsewhere in this file).
        sample_fn (SamplerFn): worker pool that fetches samples for the indices;
            constructed once by the caller so workers are reused across epochs.

    Returns:
        Whatever sample_fn.process(indices) returns (a generator of sample tuples).
    """
    indices = _fetch_py_sampler_indices(sampler, num_samples)
    return sample_fn.process(indices)
@ -3299,17 +3297,21 @@ class SamplerFn:
self . multi_process = multi_process
# Event for end of epoch
if multi_process is True :
self . eo e = multiprocessing . Event ( )
self . eo f = multiprocessing . Event ( )
else :
self . eoe = threading . Event ( )
self . eof = threading . Event ( )
# Create workers
for _ in range ( num_worker ) :
if multi_process is True :
worker = _GeneratorWorkerMp ( dataset , self . eoe )
worker = _GeneratorWorkerMp ( dataset , self . eof )
worker . daemon = True
# When multi processes fork a subprocess, the lock of the main process is copied to the subprocess,
# which may cause deadlock. Therefore, the subprocess startup is performed in the initialization phase.
# In this phase, the main process is not locked.
worker . start ( )
else :
worker = _GeneratorWorkerMt ( dataset , self . eoe , self . eof )
worker . daemon = True
worker = _GeneratorWorkerMt ( dataset , self . eo f)
worker . daemon = True
self . workers . append ( worker )
def process ( self , indices ) :
@ -3317,14 +3319,18 @@ class SamplerFn:
The main process , start the child process or child thread , and fill the index queue .
Get the result and return .
"""
for w in self . workers :
# Check whether the queue of the subprocess is empty.
if not w . queue_empty ( ) :
raise Exception ( " The queue of the subprocess is not empty. " )
# Start all workers
if not w . is_alive ( ) :
w . start ( )
# Fill initial index queues
idx_cursor = 0
idx_cursor = _fill_worker_indices ( self . workers , indices , idx_cursor )
# Start all workers
for w in self . workers :
w . start ( )
# Fetch results
for i in range ( len ( indices ) ) :
# Fetch result and put index
@ -3340,64 +3346,31 @@ class SamplerFn:
raise Exception ( " Generator worker receives KeyboardInterrupt " )
if idx_cursor < len ( indices ) :
idx_cursor = _fill_worker_indices ( self . workers , indices , idx_cursor )
# Set end-of-epoch (eoe) event once all indices are sent
if idx_cursor == len ( indices ) and not self . eoe . is_set ( ) :
self . eoe . set ( )
yield tuple ( [ np . array ( x , copy = False ) for x in result ] )
def __del__ ( self ) :
self . eoe . set ( )
if self . multi_process is False :
self . eof . set ( )
for w in self . workers :
w . join ( )
def _generator_worker_loop_mp ( dataset , idx_queue , result_queue , eoe ) :
"""
Multiprocessing generator worker process loop
"""
while True :
# Fetch index, block
try :
idx = idx_queue . get ( )
except KeyboardInterrupt :
raise Exception ( " Generator worker receives KeyboardInterrupt " )
if idx is None :
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eoe . is_set ( ) , " "
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset [ idx ]
# Send data, block
try :
result_queue . put ( result )
except KeyboardInterrupt :
raise Exception ( " Generator worker receives KeyboardInterrupt " )
del result , idx
self . eof . set ( )
def _generator_worker_loop _mt ( dataset , idx_queue , result_queu e, eo e, eof ) :
def _generator_worker_loop ( dataset , idx_queue , result_queue , eof ) :
"""
Multithread generator worker process loop .
Multithread or multiprocess generator worker process loop .
"""
while True :
# Fetch index, block
try :
# Index is generated very fast, so the timeout is very short
idx = idx_queue . get ( timeout = 0.01 )
idx = idx_queue . get ( timeout = 1 )
except KeyboardInterrupt :
raise Exception ( " Generator worker receives KeyboardInterrupt " )
except queue . Empty :
if eof . is_set ( ) or eoe . is_set ( ) :
if eof . is_set ( ) :
return
# If end-of- epoch (eoe) or end-of- file (eof) is not set, continue to get data from idx_queue
# If end-of-file (eof) is not set, continue to get data from idx_queue
continue
if idx is None :
# When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set.
assert eo e . is_set ( ) , " "
# Upon receiving None, worker process should check if eof is set.
assert eof . is_set ( ) , " "
return
if eof . is_set ( ) :
return
@ -3416,8 +3389,6 @@ def _generator_worker_loop_mt(dataset, idx_queue, result_queue, eoe, eof):
continue
break
del result , idx
if eoe . is_set ( ) and idx_queue . empty ( ) :
return
class _GeneratorWorkerMt ( threading . Thread ) :
@ -3425,10 +3396,10 @@ class _GeneratorWorkerMt(threading.Thread):
Worker process for multithread Generator .
"""
def __init__ ( self , dataset , eo e, eo f) :
def __init__ ( self , dataset , eo f) :
self . idx_queue = queue . Queue ( 16 )
self . res_queue = queue . Queue ( 16 )
super ( ) . __init__ ( target = _generator_worker_loop _mt , args = ( dataset , self . idx_queue , self . res_queu e, eo e, eof ) )
super ( ) . __init__ ( target = _generator_worker_loop , args = ( dataset , self . idx_queue , self . res_queu e, eof ) )
def put ( self , item ) :
"""
@ -3442,16 +3413,25 @@ class _GeneratorWorkerMt(threading.Thread):
"""
return self . res_queue . get ( timeout = 30 )
def queue_empty ( self ) :
if not self . idx_queue . empty ( ) :
logger . error ( " idx_queue is not empty " )
return False
if not self . res_queue . empty ( ) :
logger . error ( " res_queue is not empty " )
return False
return True
class _GeneratorWorkerMp ( multiprocessing . Process ) :
"""
Worker process for multiprocess Generator .
"""
def __init__ ( self , dataset , eoe ) :
def __init__ ( self , dataset , eo f ) :
self . idx_queue = multiprocessing . Queue ( 16 )
self . res_queue = multiprocessing . Queue ( 16 )
super ( ) . __init__ ( target = _generator_worker_loop _mp , args = ( dataset , self . idx_queue , self . res_queue , eo e ) )
super ( ) . __init__ ( target = _generator_worker_loop , args = ( dataset , self . idx_queue , self . res_queue , eo f ) )
def put ( self , item ) :
"""
@ -3467,6 +3447,15 @@ class _GeneratorWorkerMp(multiprocessing.Process):
# when we run too many iterators with infinite epoch(num_epoch=-1)
return self . res_queue . get ( timeout = 30 )
def queue_empty ( self ) :
if not self . idx_queue . empty ( ) :
logger . error ( " idx_queue is not empty " )
return False
if not self . res_queue . empty ( ) :
logger . error ( " res_queue is not empty " )
return False
return True
def __del__ ( self ) :
# Try to destruct here, sometimes the class itself will be destructed in advance,
# so "self" will be a NoneType
@ -3657,16 +3646,14 @@ class GeneratorDataset(MappableDataset):
sampler_instance . set_num_rows ( len ( self . source ) )
sampler_instance . initialize ( )
if new_op . num_parallel_workers > 1 :
new_op . source = ( lambda : _cpp_sampler_fn_mp ( sampler_instance , self . source ,
new_op . num_parallel_workers ,
self . python_multiprocessing ) )
sample_fn = SamplerFn ( self . source , new_op . num_parallel_workers , self . python_multiprocessing )
new_op . source = ( lambda : _cpp_sampler_fn_mp ( sampler_instance , sample_fn ) )
else :
new_op . source = ( lambda : _cpp_sampler_fn ( sampler_instance , self . source ) )
else :
if new_op . num_parallel_workers > 1 :
new_op . source = ( lambda : _py_sampler_fn_mp ( new_op . sampler , new_op . num_samples , self . source ,
new_op . num_parallel_workers ,
self . python_multiprocessing ) )
sample_fn = SamplerFn ( self . source , new_op . num_parallel_workers , self . python_multiprocessing )
new_op . source = ( lambda : _py_sampler_fn_mp ( new_op . sampler , new_op . num_samples , sample_fn ) )
else :
new_op . source = ( lambda : _py_sampler_fn ( new_op . sampler , new_op . num_samples , self . source ) )
else :