|
|
|
@ -23,9 +23,9 @@ class DatasetFactory(object):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def create_dataset(self, datafeed_class):
    """
    Create a dataset object by looking its class up by name in this module.

    Args:
        datafeed_class(str): name of a dataset class available in this
            module's global namespace, e.g. "InMemoryDataset".

    Returns:
        A new instance of the named class, constructed with no arguments.

    Raises:
        ValueError: if no callable with that name exists in this module.
    """
    module_namespace = globals()
    # str.capitalize() lowercases every character after the first, so a
    # CamelCase name such as "InMemoryDataset" becomes "Inmemorydataset"
    # and can never match. Try the name exactly as given first; the
    # capitalized spelling is kept as a fallback for existing callers.
    for candidate in (datafeed_class, datafeed_class.capitalize()):
        dataset_class = module_namespace.get(candidate)
        if callable(dataset_class):
            # Only the lookup is guarded: an error raised by the constructor
            # propagates as-is instead of being reported as a missing class.
            return dataset_class()
    raise ValueError("datafeed class %s does not exist" % datafeed_class)
|
|
|
|
@ -37,6 +37,7 @@ class DatasetBase(object):
|
|
|
|
|
# to decide whether we need create in memory instance
|
|
|
|
|
self.proto_desc = data_feed_pb2.DataFeedDesc()
|
|
|
|
|
self.proto_desc.pipe_command = "cat"
|
|
|
|
|
self.dataset = core.Dataset()
|
|
|
|
|
|
|
|
|
|
def set_pipe_command(self, pipe_command):
|
|
|
|
|
"""
|
|
|
|
@ -60,17 +61,23 @@ class DatasetBase(object):
|
|
|
|
|
"""
|
|
|
|
|
self.proto_desc.batch_size = batch_size
|
|
|
|
|
|
|
|
|
|
def set_thread(self, thread_num):
    """
    Forward the requested thread count to the underlying dataset object.

    Args:
        thread_num: value passed unchanged to self.dataset.set_thread_num.
    """
    backend = self.dataset
    backend.set_thread_num(thread_num)
|
|
|
|
|
|
|
|
|
|
def set_filelist(self, filelist):
    """
    Forward the list of input files to the underlying dataset object.

    Args:
        filelist: value passed unchanged to self.dataset.set_filelist.
    """
    backend = self.dataset
    backend.set_filelist(filelist)
|
|
|
|
|
|
|
|
|
|
def set_use_var(self, var_list):
|
|
|
|
|
multi_slot = self.proto_desc.multi_slot_desc()
|
|
|
|
|
multi_slot = self.proto_desc.multi_slot_desc
|
|
|
|
|
for var in var_list:
|
|
|
|
|
slot_var = multi_slot.add()
|
|
|
|
|
slot_var = multi_slot.slots.add()
|
|
|
|
|
slot_var.is_used = True
|
|
|
|
|
slot_var.name = var.name
|
|
|
|
|
if var.lod_level == 0:
|
|
|
|
|
slot_var.is_dense = True
|
|
|
|
|
if var.dtype == core.VarType.FP32:
|
|
|
|
|
if var.dtype == core.VarDesc.VarType.FP32:
|
|
|
|
|
slot_var.type = "float32"
|
|
|
|
|
elif var.dtype == core.VarType.INT64:
|
|
|
|
|
elif var.dtype == core.VarDesc.VarType.INT64:
|
|
|
|
|
slot_var.type = "uint64"
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
@ -93,17 +100,24 @@ class DatasetBase(object):
|
|
|
|
|
|
|
|
|
|
class InMemoryDataset(DatasetBase):
    """
    Dataset whose data feed is named "MultiSlotInMemoryDataFeed".

    Every operation is delegated to the backend object held in
    self.dataset, which the base class constructor is expected to create.
    """

    def __init__(self):
        # The two-argument form is required here: the earlier
        # super(InMemoryDataset.__init__()) called __init__ with no instance
        # and raised TypeError before the base class was initialised.
        super(InMemoryDataset, self).__init__()
        self.proto_desc.name = "MultiSlotInMemoryDataFeed"

    def load_into_memory(self):
        """
        Pass the data feed description to the backend, then ask it to load.
        """
        # self.desc() is defined outside this class; presumably it renders
        # self.proto_desc -- confirm against DatasetBase.
        self.dataset.set_data_feed_desc(self.desc())
        self.dataset.load_into_memory()

    def local_shuffle(self):
        """
        Ask the backend dataset to perform a local shuffle.
        """
        self.dataset.local_shuffle()

    def global_shuffle(self):
        """
        Tell the backend the worker count, then ask it for a global shuffle.
        """
        # Imported locally, as in the original code.
        from .distributed import ps_instance
        # NOTE(review): the arguments 1 and 2 are hard-coded and their
        # meaning is not visible from here -- confirm against
        # PaddlePSInstance before relying on them.
        instance = ps_instance.PaddlePSInstance(1, 2)
        self.dataset.set_trainer_num(instance.get_worker_num())
        # Delegate to the backend; calling self.global_shuffle() here would
        # recurse until the interpreter's recursion limit is hit.
        # Assumes the backend exposes global_shuffle() alongside
        # local_shuffle() -- TODO confirm.
        self.dataset.global_shuffle()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class QueueDataset(DatasetBase):
    """
    Dataset whose data feed is named "MultiSlotDataFeed".
    """

    def __init__(self):
        # The two-argument form is required here: the earlier
        # super(QueueDataset.__init__()) called __init__ with no instance
        # and raised TypeError before the base class was initialised.
        super(QueueDataset, self).__init__()
        self.proto_desc.name = "MultiSlotDataFeed"
|
|
|
|
|