!10137 fix GeneratorDataset multiprocessing hangs

From: @heleiwang
Reviewed-by: 
Signed-off-by:
pull/10137/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ea2cabcfec

@ -22,10 +22,12 @@ import glob
import json import json
import math import math
import os import os
import signal
import uuid import uuid
import multiprocessing import multiprocessing
import queue import queue
from enum import Enum from enum import Enum
from functools import partial
from importlib import import_module from importlib import import_module
import sys import sys
import threading import threading
@ -3447,6 +3449,7 @@ class SamplerFn:
self.workers = [] self.workers = []
self.num_worker = num_worker self.num_worker = num_worker
self.multi_process = multi_process self.multi_process = multi_process
self.joined = False
# Event for end of epoch # Event for end of epoch
if multi_process is True: if multi_process is True:
self.eof = multiprocessing.Event() self.eof = multiprocessing.Event()
@ -3485,29 +3488,47 @@ class SamplerFn:
# Fetch results # Fetch results
for i in range(len(indices)): for i in range(len(indices)):
if self.eof.is_set():
self._stop_subprocess()
return
# Fetch result and put index # Fetch result and put index
try: try:
result = self.workers[i % self.num_worker].get() result = self.workers[i % self.num_worker].get()
except queue.Empty: except queue.Empty:
self._stop_subprocess()
raise Exception("Generator worker process timeout.") raise Exception("Generator worker process timeout.")
except KeyboardInterrupt: except KeyboardInterrupt:
self.eof.set() self._stop_subprocess()
for w in self.workers:
w.terminate()
w.join()
raise Exception("Generator worker receives KeyboardInterrupt.") raise Exception("Generator worker receives KeyboardInterrupt.")
if self.eof.is_set():
self._stop_subprocess()
return
if idx_cursor < len(indices): if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
yield tuple([np.array(x, copy=False) for x in result]) yield tuple([np.array(x, copy=False) for x in result])
def __del__(self): def _stop_subprocess(self):
self.eof.set() self.eof.set()
if self.joined is False:
for w in self.workers:
w.join()
self.joined = True
def __del__(self):
self._stop_subprocess()
def _generator_worker_loop(dataset, idx_queue, result_queue, eof): def _subprocess_handle(eof, signum, frame):
logger.info("The subprocess receives a termination signal.")
eof.set()
def _generator_worker_loop(dataset, idx_queue, result_queue, eof, is_multiprocessing):
""" """
Multithread or multiprocess generator worker process loop. Multithread or multiprocess generator worker process loop.
""" """
if is_multiprocessing:
signal.signal(signal.SIGTERM, partial(_subprocess_handle, eof))
while True: while True:
# Fetch index, block # Fetch index, block
try: try:
@ -3516,6 +3537,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
raise Exception("Generator worker receives KeyboardInterrupt.") raise Exception("Generator worker receives KeyboardInterrupt.")
except queue.Empty: except queue.Empty:
if eof.is_set(): if eof.is_set():
if is_multiprocessing:
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
return return
# If end-of-file (eof) is not set, continue to get data from idx_queue # If end-of-file (eof) is not set, continue to get data from idx_queue
continue continue
@ -3525,6 +3549,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
assert eof.is_set(), "" assert eof.is_set(), ""
return return
if eof.is_set(): if eof.is_set():
if is_multiprocessing:
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
return return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx] result = dataset[idx]
@ -3536,6 +3563,9 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eof):
raise Exception("Generator worker receives KeyboardInterrupt.") raise Exception("Generator worker receives KeyboardInterrupt.")
except queue.Full: except queue.Full:
if eof.is_set(): if eof.is_set():
if is_multiprocessing:
idx_queue.cancel_join_thread()
result_queue.cancel_join_thread()
return return
# If eof is not set, continue to put data to result_queue # If eof is not set, continue to put data to result_queue
continue continue
@ -3551,7 +3581,7 @@ class _GeneratorWorkerMt(threading.Thread):
def __init__(self, dataset, eof): def __init__(self, dataset, eof):
self.idx_queue = queue.Queue(16) self.idx_queue = queue.Queue(16)
self.res_queue = queue.Queue(16) self.res_queue = queue.Queue(16)
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof)) super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, False))
def put(self, item): def put(self, item):
""" """
@ -3567,10 +3597,10 @@ class _GeneratorWorkerMt(threading.Thread):
def queue_empty(self): def queue_empty(self):
if not self.idx_queue.empty(): if not self.idx_queue.empty():
logger.error("idx_queue is not empty") logger.warning("idx_queue is not empty")
return False return False
if not self.res_queue.empty(): if not self.res_queue.empty():
logger.error("res_queue is not empty") logger.warning("res_queue is not empty")
return False return False
return True return True
@ -3583,7 +3613,7 @@ class _GeneratorWorkerMp(multiprocessing.Process):
def __init__(self, dataset, eof): def __init__(self, dataset, eof):
self.idx_queue = multiprocessing.Queue(16) self.idx_queue = multiprocessing.Queue(16)
self.res_queue = multiprocessing.Queue(16) self.res_queue = multiprocessing.Queue(16)
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof)) super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eof, True))
def put(self, item): def put(self, item):
""" """
@ -3601,21 +3631,13 @@ class _GeneratorWorkerMp(multiprocessing.Process):
def queue_empty(self): def queue_empty(self):
if not self.idx_queue.empty(): if not self.idx_queue.empty():
logger.error("idx_queue is not empty.") logger.warning("idx_queue is not empty.")
return False return False
if not self.res_queue.empty(): if not self.res_queue.empty():
logger.error("res_queue is not empty.") logger.warning("res_queue is not empty.")
return False return False
return True return True
def __del__(self):
# Try to destruct here, sometimes the class itself will be destructed in advance,
# so "self" will be a NoneType
try:
self.terminate()
except AttributeError:
pass
class GeneratorDataset(MappableDataset): class GeneratorDataset(MappableDataset):
""" """

Loading…
Cancel
Save