|
|
@ -145,7 +145,7 @@ class DatasetHelper:
|
|
|
|
self.iter = iterclass(dataset, sink_size, epoch_num)
|
|
|
|
self.iter = iterclass(dataset, sink_size, epoch_num)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
iterclass = _DatasetIterNormal
|
|
|
|
iterclass = _DatasetIterNormal
|
|
|
|
self.iter = iterclass(dataset)
|
|
|
|
self.iter = iterclass(dataset, epoch_num=epoch_num)
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
def __iter__(self):
|
|
|
|
return self.iter.__iter__()
|
|
|
|
return self.iter.__iter__()
|
|
|
@ -290,11 +290,12 @@ class _DatasetIterPSLite(_DatasetIter):
|
|
|
|
|
|
|
|
|
|
|
|
class _DatasetIterNormal:
|
|
|
|
class _DatasetIterNormal:
|
|
|
|
"""Iter for normal(non sink) mode, feed the data from host."""
|
|
|
|
"""Iter for normal(non sink) mode, feed the data from host."""
|
|
|
|
def __init__(self, dataset):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dataset, epoch_num=-1):
|
|
|
|
self.dataset = dataset
|
|
|
|
self.dataset = dataset
|
|
|
|
self.device_num = _get_device_num()
|
|
|
|
self.device_num = _get_device_num()
|
|
|
|
self.global_rank = _get_global_rank()
|
|
|
|
self.global_rank = _get_global_rank()
|
|
|
|
self.iter = self.dataset.create_tuple_iterator()
|
|
|
|
self.iter = self.dataset.create_tuple_iterator(num_epochs=epoch_num)
|
|
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
def __iter__(self):
|
|
|
|
return self
|
|
|
|
return self
|
|
|
|