|
|
@ -3450,6 +3450,7 @@ class SamplerFn:
|
|
|
|
self.num_worker = num_worker
|
|
|
|
self.num_worker = num_worker
|
|
|
|
self.multi_process = multi_process
|
|
|
|
self.multi_process = multi_process
|
|
|
|
self.joined = False
|
|
|
|
self.joined = False
|
|
|
|
|
|
|
|
self.ppid = os.getpid()
|
|
|
|
# Event for end of epoch
|
|
|
|
# Event for end of epoch
|
|
|
|
if multi_process is True:
|
|
|
|
if multi_process is True:
|
|
|
|
self.eof = multiprocessing.Event()
|
|
|
|
self.eof = multiprocessing.Event()
|
|
|
@ -3508,11 +3509,12 @@ class SamplerFn:
|
|
|
|
yield tuple([np.array(x, copy=False) for x in result])
|
|
|
|
yield tuple([np.array(x, copy=False) for x in result])
|
|
|
|
|
|
|
|
|
|
|
|
def _stop_subprocess(self):
|
|
|
|
def _stop_subprocess(self):
|
|
|
|
|
|
|
|
# Only the main process can call join
|
|
|
|
|
|
|
|
if self.joined is False and self.ppid == os.getpid():
|
|
|
|
self.eof.set()
|
|
|
|
self.eof.set()
|
|
|
|
if self.joined is False:
|
|
|
|
self.joined = True
|
|
|
|
for w in self.workers:
|
|
|
|
for w in self.workers:
|
|
|
|
w.join()
|
|
|
|
w.join()
|
|
|
|
self.joined = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
def __del__(self):
|
|
|
|
self._stop_subprocess()
|
|
|
|
self._stop_subprocess()
|
|
|
|