|
|
|
@ -15,7 +15,7 @@
|
|
|
|
|
from paddle.fluid.proto import data_feed_pb2
|
|
|
|
|
from google.protobuf import text_format
|
|
|
|
|
from . import core
|
|
|
|
|
__all__ = ['DatasetFactory']
|
|
|
|
|
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DatasetFactory(object):
|
|
|
|
@ -38,6 +38,10 @@ class DatasetFactory(object):
|
|
|
|
|
"""
|
|
|
|
|
Create "QueueDataset" or "InMemoryDataset",
|
|
|
|
|
the default is "QueueDataset".
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
dataset = fluid.DatasetFactory().create_dataset()
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
dataset = globals()[datafeed_class]()
|
|
|
|
@ -177,7 +181,8 @@ class DatasetBase(object):
|
|
|
|
|
class InMemoryDataset(DatasetBase):
|
|
|
|
|
"""
|
|
|
|
|
InMemoryDataset, it will load data into memory
|
|
|
|
|
and shuffle data before training
|
|
|
|
|
and shuffle data before training.
|
|
|
|
|
This class should be created by DatasetFactory
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
|
|
|
|
@ -259,7 +264,8 @@ class QueueDataset(DatasetBase):
|
|
|
|
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
"""
|
|
|
|
|
Init
|
|
|
|
|
Initialize QueueDataset
|
|
|
|
|
This class should be created by DatasetFactory
|
|
|
|
|
"""
|
|
|
|
|
super(QueueDataset, self).__init__()
|
|
|
|
|
self.proto_desc.name = "MultiSlotDataFeed"
|
|
|
|
@ -268,7 +274,8 @@ class QueueDataset(DatasetBase):
|
|
|
|
|
"""
|
|
|
|
|
Local shuffle
|
|
|
|
|
|
|
|
|
|
QueueDataset does not support local shuffle
|
|
|
|
|
Local shuffle is not supported in QueueDataset
|
|
|
|
|
NotImplementedError will be raised
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"QueueDataset does not support local shuffle, "
|
|
|
|
@ -276,7 +283,8 @@ class QueueDataset(DatasetBase):
|
|
|
|
|
|
|
|
|
|
def global_shuffle(self, fleet=None):
|
|
|
|
|
"""
|
|
|
|
|
Global shuffle
|
|
|
|
|
Global shuffle is not supported in QueueDataset
|
|
|
|
|
NotImplementedError will be raised
|
|
|
|
|
"""
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"QueueDataset does not support global shuffle, "
|
|
|
|
|