|
|
|
|
@ -144,6 +144,8 @@ class WaitedDSCallback(Callback, DSCallback):
|
|
|
|
|
self.epoch_event = threading.Event()
|
|
|
|
|
self.epoch_run_context = None
|
|
|
|
|
|
|
|
|
|
self.training_ended = False
|
|
|
|
|
|
|
|
|
|
def sync_epoch_begin(self, train_run_context, ds_run_context):
|
|
|
|
|
"""
|
|
|
|
|
Called before a new dataset epoch is started and after the previous training epoch is ended.
|
|
|
|
|
@ -180,6 +182,7 @@ class WaitedDSCallback(Callback, DSCallback):
|
|
|
|
|
ds_run_context: Include some information of the pipeline.
|
|
|
|
|
"""
|
|
|
|
|
if ds_run_context.cur_epoch_num > 1:
|
|
|
|
|
if not self.training_ended:
|
|
|
|
|
success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
|
|
|
|
|
self.epoch_event.clear()
|
|
|
|
|
if not success:
|
|
|
|
|
@ -205,6 +208,7 @@ class WaitedDSCallback(Callback, DSCallback):
|
|
|
|
|
ds_run_context: Include some information of the pipeline.
|
|
|
|
|
"""
|
|
|
|
|
if ds_run_context.cur_step_num > self.step_size:
|
|
|
|
|
if not self.training_ended:
|
|
|
|
|
success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
|
|
|
|
|
self.step_event.clear()
|
|
|
|
|
if not success:
|
|
|
|
|
@ -233,3 +237,8 @@ class WaitedDSCallback(Callback, DSCallback):
|
|
|
|
|
raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")
|
|
|
|
|
|
|
|
|
|
return c_cb
|
|
|
|
|
|
|
|
|
|
def end(self, run_context):
|
|
|
|
|
self.epoch_end(run_context)
|
|
|
|
|
self.step_end(run_context)
|
|
|
|
|
self.training_ended = True
|
|
|
|
|
|