refine dataset API

test=develop
revert-16555-model_data_cryption_link_all_lib
dongdaxiang 6 years ago
parent 359fec0567
commit 2c5839f7e3

@ -15,7 +15,7 @@
from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format
from . import core
__all__ = ['DatasetFactory']
__all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
class DatasetFactory(object):
@ -38,6 +38,10 @@ class DatasetFactory(object):
"""
Create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset".
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
"""
try:
dataset = globals()[datafeed_class]()
@ -177,7 +181,8 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase):
"""
InMemoryDataset, it will load data into memory
and shuffle data before training
and shuffle data before training.
This class should be created by DatasetFactory
Example:
dataset = paddle.fluid.DatasetFactory.create_dataset("InMemoryDataset")
@ -259,7 +264,8 @@ class QueueDataset(DatasetBase):
def __init__(self):
"""
Init
Initialize QueueDataset
This class should be created by DatasetFactory
"""
super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed"
@ -268,7 +274,8 @@ class QueueDataset(DatasetBase):
"""
Local shuffle
QueueDataset does not support local shuffle
Local shuffle is not supported in QueueDataset
NotImplementedError will be raised
"""
raise NotImplementedError(
"QueueDataset does not support local shuffle, "
@ -276,7 +283,8 @@ class QueueDataset(DatasetBase):
def global_shuffle(self, fleet=None):
"""
Global shuffle
Global shuffle is not supported in QueueDataset
NotImplementedError will be raised
"""
raise NotImplementedError(
"QueueDataset does not support global shuffle, "

Loading…
Cancel
Save