refine dataset API

test=develop
revert-16555-model_data_cryption_link_all_lib
dongdaxiang 6 years ago
parent 359fec0567
commit 2c5839f7e3

@ -15,7 +15,7 @@
from paddle.fluid.proto import data_feed_pb2 from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format from google.protobuf import text_format
from . import core from . import core
__all__ = ['DatasetFactory'] __all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
class DatasetFactory(object): class DatasetFactory(object):
@ -38,6 +38,10 @@ class DatasetFactory(object):
""" """
Create "QueueDataset" or "InMemoryDataset", Create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset". the default is "QueueDataset".
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
""" """
try: try:
dataset = globals()[datafeed_class]() dataset = globals()[datafeed_class]()
@ -177,7 +181,8 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase): class InMemoryDataset(DatasetBase):
""" """
InMemoryDataset, it will load data into memory InMemoryDataset, it will load data into memory
and shuffle data before training and shuffle data before training.
This class should be created by DatasetFactory
Example: Example:
dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset") dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
@ -259,7 +264,8 @@ class QueueDataset(DatasetBase):
def __init__(self): def __init__(self):
""" """
Init Initialize QueueDataset
This class should be created by DatasetFactory
""" """
super(QueueDataset, self).__init__() super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed" self.proto_desc.name = "MultiSlotDataFeed"
@ -268,7 +274,8 @@ class QueueDataset(DatasetBase):
""" """
Local shuffle Local shuffle
QueueDataset does not support local shuffle Local shuffle is not supported in QueueDataset
NotImplementedError will be raised
""" """
raise NotImplementedError( raise NotImplementedError(
"QueueDataset does not support local shuffle, " "QueueDataset does not support local shuffle, "
@ -276,7 +283,8 @@ class QueueDataset(DatasetBase):
def global_shuffle(self, fleet=None): def global_shuffle(self, fleet=None):
""" """
Global shuffle Global shuffle is not supported in QueueDataset
NotImplementedError will be raised
""" """
raise NotImplementedError( raise NotImplementedError(
"QueueDataset does not support global shuffle, " "QueueDataset does not support global shuffle, "

Loading…
Cancel
Save