You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/trainer_config_helpers/data_sources.py

204 lines
7.5 KiB

# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data Sources are helpers to define paddle training data or testing data.
"""
from paddle.trainer.config_parser import *
from .utils import deprecated
try:
import cPickle as pickle
except ImportError:
import pickle
__all__ = ['define_py_data_sources2']
def define_py_data_source(file_list, cls, module,
obj, args=None, async=False,
data_cls=PyData):
"""
Define a python data source.
For example, the simplest usage in trainer_config.py as follow:
.. code-block:: python
define_py_data_source("train.list", TrainData, "data_provider", "process")
Or. if you want to pass arguments from trainer_config to data_provider.py, then
.. code-block:: python
define_py_data_source("train.list", TrainData, "data_provider", "process",
args={"dictionary": dict_name})
:param data_cls:
:param file_list: file list name, which contains all data file paths
:type file_list: basestring
:param cls: Train or Test Class.
:type cls: TrainData or TestData
:param module: python module name.
:type module: basestring
:param obj: python object name. May be a function name if using
PyDataProviderWrapper.
:type obj: basestring
:param args: The best practice is using dict to pass arguments into
DataProvider, and use :code:`@init_hook_wrapper` to
receive arguments.
:type args: string or picklable object
:param async: Load Data asynchronously or not.
:type async: bool
:return: None
:rtype: None
"""
if isinstance(file_list, list):
file_list_name = 'train.list'
if isinstance(cls, TestData):
file_list_name = 'test.list'
with open(file_list_name, 'r') as f:
f.writelines(file_list)
file_list = file_list_name
if not isinstance(args, basestring) and args is not None:
args = pickle.dumps(args, 0)
if data_cls is None:
def py_data2(files, load_data_module, load_data_object, load_data_args,
**kwargs):
data = DataBase()
data.type = 'py2'
data.files = files
data.load_data_module = load_data_module
data.load_data_object = load_data_object
data.load_data_args = load_data_args
return data
data_cls = py_data2
cls(data_cls(files=file_list,
load_data_module=module,
load_data_object=obj,
load_data_args=args,
async_load_data=async))
def define_py_data_sources(train_list, test_list, module, obj, args=None,
                           train_async=False, data_cls=PyData):
    """
    The annotation is almost the same as define_py_data_sources2, except that
    it can specify train_async and data_cls.

    :param train_list: Train list name. May be None if test_list is given.
    :type train_list: basestring
    :param test_list: Test list name. May be None if train_list is given.
    :type test_list: basestring
    :param module: python module name. If train and test is different, then
                   pass a tuple or list to this argument.
    :type module: basestring or tuple or list
    :param obj: python object name. May be a function name if using
                PyDataProviderWrapper. If train and test is different, then pass
                a tuple or list to this argument.
    :type obj: basestring or tuple or list
    :param args: The best practice is using dict() to pass arguments into
                 DataProvider, and use :code:`@init_hook_wrapper` to receive
                 arguments. If train and test is different, then pass a tuple
                 or list to this argument. None is replaced by an empty
                 string.
    :type args: string or picklable object or list or tuple.
    :param train_async: Is training data load asynchronously or not. Test
                        data is always defined with async disabled.
    :type train_async: bool
    :param data_cls: forwarded unchanged to define_py_data_source for both
                     the train and the test source.
    :return: None
    :rtype: None
    """

    def _is_pair(o):
        # A list/tuple of exactly two items is read as (train, test).
        return isinstance(o, (list, tuple)) and len(o) == 2

    assert train_list is not None or test_list is not None
    assert module is not None and obj is not None

    train_module = test_module = module
    if _is_pair(module):
        train_module, test_module = module

    train_obj = test_obj = obj
    if _is_pair(obj):
        # Split obj itself; the module pair was already handled above.
        train_obj, test_obj = obj

    if args is None:
        args = ""

    train_args = test_args = args
    if _is_pair(args):
        train_args, test_args = args

    if train_list is not None:
        define_py_data_source(train_list, TrainData, train_module, train_obj,
                              train_args, train_async, data_cls)

    if test_list is not None:
        define_py_data_source(test_list, TestData, test_module, test_obj,
                              test_args, False, data_cls)
def define_py_data_sources2(train_list, test_list, module, obj, args=None):
    """
    Define python Train/Test data sources in one method. If train/test use
    the same Data Provider configuration, module/obj/args contain one argument,
    otherwise contain a list or tuple of arguments. For example:

    ..  code-block:: python

        define_py_data_sources2(train_list="train.list",
                                test_list="test.list",
                                module="data_provider",
                                # if train/test use different configurations,
                                # obj=["process_train", "process_test"]
                                obj="process",
                                args={"dictionary": dict_name})

    The related data provider can refer to
    `here <../../data_provider/pydataprovider2.html#dataprovider-for-the-sequential-model>`__.

    This is a thin wrapper: every argument is forwarded to
    define_py_data_sources with data_cls set to None and train_async left at
    its default.

    :param train_list: Train list name.
    :type train_list: basestring
    :param test_list: Test list name.
    :type test_list: basestring
    :param module: python module name. If train and test is different, then
                   pass a tuple or list to this argument.
    :type module: basestring or tuple or list
    :param obj: python object name. May be a function name if using
                PyDataProviderWrapper. If train and test is different, then pass
                a tuple or list to this argument.
    :type obj: basestring or tuple or list
    :param args: The best practice is using dict() to pass arguments into
                 DataProvider, and use :code:`@init_hook_wrapper` to receive
                 arguments. If train and test is different, then pass a tuple
                 or list to this argument.
    :type args: string or picklable object or list or tuple.
    :return: None
    :rtype: None
    """
    # The first five parameters line up positionally with the callee; only
    # data_cls has to be named so that train_async keeps its default.
    define_py_data_sources(train_list, test_list, module, obj, args,
                           data_cls=None)