|
|
|
@ -22,10 +22,12 @@ import glob
|
|
|
|
|
import json
|
|
|
|
|
import math
|
|
|
|
|
import os
|
|
|
|
|
import signal
|
|
|
|
|
import uuid
|
|
|
|
|
import multiprocessing
|
|
|
|
|
import queue
|
|
|
|
|
from enum import Enum
|
|
|
|
|
from functools import partial
|
|
|
|
|
from importlib import import_module
|
|
|
|
|
import sys
|
|
|
|
|
import threading
|
|
|
|
@ -3447,6 +3449,7 @@ class SamplerFn:
|
|
|
|
|
self.workers = []
|
|
|
|
|
self.num_worker = num_worker
|
|
|
|
|
self.multi_process = multi_process
|
|
|
|
|
self.joined = False
|
|
|
|
|
# Event for end of epoch
|
|
|
|
|
if multi_process is True:
|
|
|
|
|
self.eof = multiprocessing.Event()
|
|
|
|
@ -3485,29 +3488,47 @@ class SamplerFn:
|
|
|
|
|
|
|
|
|
|
# Fetch results
|
|
|
|
|
for i in range(len(indices)):
|
|
|
|
|
if self.eof.is_set():
|
|
|
|
|
self._stop_subprocess()
|
|
|
|
|
return
|
|
|
|
|
# Fetch result and put index
|
|
|
|
|
try:
|
|
|
|
|
result = self.workers[i % self.num_worker].get()
|
|
|
|
|
except queue.Empty:
|
|
|
|
|
self._stop_subprocess()
|
|
|
|
|
raise Exception("Generator worker process timeout.")
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
|
self.eof.set()
|
|
|
|
|
for w in self.workers:
|
|
|
|
|
w.terminate()
|
|
|
|
|
w.join()
|
|
|
|
|
self._stop_subprocess()
|
|
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt.")
|
|
|
|
|
if self.eof.is_set():
|
|
|
|
|
self._stop_subprocess()
|
|
|
|
|
return
|
|
|
|
|
if idx_cursor < len(indices):
|
|
|
|
|
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
|
|
|
|
|
yield tuple([np.array(x, copy=False) for x in result])
|
|
|
|
|
|
|
|
|
|
def __del__(self):
|
|
|
|
|
def _stop_subprocess(self):
    """
    Signal end of epoch to the workers and wait for them to exit.

    Sets the shared ``eof`` event and then joins every entry of
    ``self.workers``. The ``joined`` flag guarantees the join step runs at
    most once; later calls only set the (already set) event again.
    """
    self.eof.set()
    if not self.joined:
        for worker in self.workers:
            worker.join()
        self.joined = True
|
|
|
|
|
|
|
|
|
|
def __del__(self):
    """
    Stop the workers when this object is garbage-collected.
    """
    # The finalizer also runs for an instance whose __init__ raised before
    # ``eof`` was assigned; there is no event to set in that case, so skip the
    # shutdown instead of raising AttributeError from inside __del__.
    if hasattr(self, "eof"):
        self._stop_subprocess()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
|
|
|
|
|
def _subprocess_handle(eof, signum, frame):
    """
    Signal handler installed in generator worker subprocesses.

    Logs that a termination signal arrived and sets the end-of-epoch event.

    Args:
        eof: Event-like object; its ``set()`` method is called.
        signum: Signal number supplied by the ``signal`` module (unused).
        frame: Current stack frame supplied by the ``signal`` module (unused).
    """
    logger.info("The subprocess receives a termination signal.")
    # Only flag the event here; nothing else is done inside the handler.
    eof.set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing):
|
|
|
|
|
"""
|
|
|
|
|
Multithread or multiprocess generator worker process loop.
|
|
|
|
|
"""
|
|
|
|
|
if is_multiprocessing:
|
|
|
|
|
signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof))
|
|
|
|
|
while True:
|
|
|
|
|
# Fetch index, block
|
|
|
|
|
try:
|
|
|
|
@ -3516,6 +3537,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
|
|
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt.")
|
|
|
|
|
except queue.Empty:
|
|
|
|
|
if eof.is_set():
|
|
|
|
|
if is_multiprocessing:
|
|
|
|
|
idx_queue.cancel_join_thread()
|
|
|
|
|
result_queue.cancel_join_thread()
|
|
|
|
|
return
|
|
|
|
|
# If end-of-file (eof) is not set, continue to get data from idx_queue
|
|
|
|
|
continue
|
|
|
|
@ -3525,6 +3549,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
|
|
|
|
|
assert eof.is_set(), ""
|
|
|
|
|
return
|
|
|
|
|
if eof.is_set():
|
|
|
|
|
if is_multiprocessing:
|
|
|
|
|
idx_queue.cancel_join_thread()
|
|
|
|
|
result_queue.cancel_join_thread()
|
|
|
|
|
return
|
|
|
|
|
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
|
|
|
|
|
result = dataset[idx]
|
|
|
|
@ -3536,6 +3563,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
|
|
|
|
|
raise Exception("Generator worker receives KeyboardInterrupt.")
|
|
|
|
|
except queue.Full:
|
|
|
|
|
if eof.is_set():
|
|
|
|
|
if is_multiprocessing:
|
|
|
|
|
idx_queue.cancel_join_thread()
|
|
|
|
|
result_queue.cancel_join_thread()
|
|
|
|
|
return
|
|
|
|
|
# If eof is not set, continue to put data to result_queue
|
|
|
|
|
continue
|
|
|
|
@ -3551,7 +3581,7 @@ class _GeneratorWorkerMt(threading.Thread):
|
|
|
|
|
def __init__(self, dataset, eof):
    """
    Create the worker's queues and bind the worker loop as the thread target.

    Args:
        dataset: Object forwarded to ``_generator_worker_loop`` as its dataset.
        eof: End-of-epoch event forwarded to ``_generator_worker_loop``.
    """
    # Bounded queues (capacity 16) for incoming indices and outgoing results.
    self.idx_queue = queue.Queue(16)
    self.res_queue = queue.Queue(16)
    # Initialise the base class exactly once. The trailing False is the
    # ``is_multiprocessing`` argument of ``_generator_worker_loop``.
    super().__init__(
        target=_generator_worker_loop,
        args=(dataset, self.idx_queue, self.res_queue, eof, False),
    )
|
|
|
|
|
|
|
|
|
|
def put(self, item):
|
|
|
|
|
"""
|
|
|
|
@ -3567,10 +3597,10 @@ class _GeneratorWorkerMt(threading.Thread):
|
|
|
|
|
|
|
|
|
|
def queue_empty(self):
    """
    Check whether both worker queues are empty.

    Returns:
        bool, True when ``idx_queue`` and ``res_queue`` are both empty.
        Otherwise the first non-empty queue is reported once at warning level
        and False is returned.
    """
    if not self.idx_queue.empty():
        logger.warning("idx_queue is not empty")
        return False
    if not self.res_queue.empty():
        logger.warning("res_queue is not empty")
        return False
    return True
|
|
|
|
|
|
|
|
|
@ -3583,7 +3613,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
|
|
|
|
|
def __init__(self, dataset, eof):
    """
    Create the worker's queues and bind the worker loop as the process target.

    Args:
        dataset: Object forwarded to ``_generator_worker_loop`` as its dataset.
        eof: End-of-epoch event forwarded to ``_generator_worker_loop``.
    """
    # Bounded inter-process queues (capacity 16) for indices and results.
    self.idx_queue = multiprocessing.Queue(16)
    self.res_queue = multiprocessing.Queue(16)
    # Initialise the base class exactly once. The trailing True is the
    # ``is_multiprocessing`` argument of ``_generator_worker_loop``.
    super().__init__(
        target=_generator_worker_loop,
        args=(dataset, self.idx_queue, self.res_queue, eof, True),
    )
|
|
|
|
|
|
|
|
|
|
def put(self, item):
|
|
|
|
|
"""
|
|
|
|
@ -3601,21 +3631,13 @@ class _GeneratorWorkerMp(multiprocessing.Process):
|
|
|
|
|
|
|
|
|
|
def queue_empty(self):
    """
    Check whether both worker queues are empty.

    Returns:
        bool, True when ``idx_queue`` and ``res_queue`` are both empty.
        Otherwise the first non-empty queue is reported once at warning level
        and False is returned.
    """
    if not self.idx_queue.empty():
        logger.warning("idx_queue is not empty.")
        return False
    if not self.res_queue.empty():
        logger.warning("res_queue is not empty.")
        return False
    return True
|
|
|
|
|
|
|
|
|
|
def __del__(self):
    """
    Best-effort termination of the worker when the object is destroyed.
    """
    # Finalization may run when the object is already partly torn down, so an
    # AttributeError raised while terminating is deliberately ignored.
    try:
        self.terminate()
    except AttributeError:
        return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GeneratorDataset(MappableDataset):
|
|
|
|
|
"""
|
|
|
|
|