Merge pull request #16597 from guru4elephant/refine_dataset

refine dataset API
revert-16555-model_data_cryption_link_all_lib
guru4elephant 7 years ago committed by GitHub
commit be61e9eab8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,7 +15,7 @@
from paddle.fluid.proto import data_feed_pb2 from paddle.fluid.proto import data_feed_pb2
from google.protobuf import text_format from google.protobuf import text_format
from . import core from . import core
__all__ = ['DatasetFactory'] __all__ = ['DatasetFactory', 'InMemoryDataset', 'QueueDataset']
class DatasetFactory(object): class DatasetFactory(object):
@ -38,6 +38,10 @@ class DatasetFactory(object):
""" """
Create "QueueDataset" or "InMemoryDataset", Create "QueueDataset" or "InMemoryDataset",
the default is "QueueDataset". the default is "QueueDataset".
Examples:
import paddle.fluid as fluid
dataset = fluid.DatasetFactory().create_dataset()
""" """
try: try:
dataset = globals()[datafeed_class]() dataset = globals()[datafeed_class]()
@ -177,7 +181,8 @@ class DatasetBase(object):
class InMemoryDataset(DatasetBase): class InMemoryDataset(DatasetBase):
""" """
InMemoryDataset, it will load data into memory InMemoryDataset, it will load data into memory
and shuffle data before training and shuffle data before training.
This class should be created by DatasetFactory
Example: Example:
dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset") dataset = paddle.fluid.DatasetFactory().create_dataset("InMemoryDataset")
@ -259,7 +264,8 @@ class QueueDataset(DatasetBase):
def __init__(self): def __init__(self):
""" """
Init Initialize QueueDataset
This class should be created by DatasetFactory
""" """
super(QueueDataset, self).__init__() super(QueueDataset, self).__init__()
self.proto_desc.name = "MultiSlotDataFeed" self.proto_desc.name = "MultiSlotDataFeed"
@ -268,7 +274,8 @@ class QueueDataset(DatasetBase):
""" """
Local shuffle Local shuffle
QueueDataset does not support local shuffle Local shuffle is not supported in QueueDataset
NotImplementedError will be raised
""" """
raise NotImplementedError( raise NotImplementedError(
"QueueDataset does not support local shuffle, " "QueueDataset does not support local shuffle, "
@ -276,7 +283,8 @@ class QueueDataset(DatasetBase):
def global_shuffle(self, fleet=None): def global_shuffle(self, fleet=None):
""" """
Global shuffle Global shuffle is not supported in QueueDataset
NotImplementedError will be raised
""" """
raise NotImplementedError( raise NotImplementedError(
"QueueDataset does not support global shuffle, " "QueueDataset does not support global shuffle, "

Loading…
Cancel
Save