fix multithreading error in GeneratorDataset

pull/12795/head
heleiwang 4 years ago
parent 0f3c5b1d0f
commit 75bf8a4714

@ -3174,7 +3174,7 @@ class SamplerFn:
self.workers = []
self.num_worker = num_worker
self.multi_process = multi_process
self.joined = False
self.need_join = False
self.ppid = os.getpid()
self.pid = []
# Event for end of epoch
@ -3192,6 +3192,7 @@ class SamplerFn:
# In this phase, the main process is not locked.
worker.start()
self.pid.append(worker.pid)
self.need_join = True
else:
worker = _GeneratorWorkerMt(dataset, self.eof)
worker.daemon = True
@ -3237,9 +3238,9 @@ class SamplerFn:
def _stop_subprocess(self):
# Only the main process can call join
if self.joined is False and self.ppid == os.getpid():
if self.need_join is True and self.ppid == os.getpid():
self.eof.set()
self.joined = True
self.need_join = False
for w in self.workers:
w.join()

Loading…
Cancel
Save