You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
204 lines
7.5 KiB
204 lines
7.5 KiB
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Data Sources are helpers to define paddle training data or testing data.
|
|
"""
|
|
from paddle.trainer.config_parser import *
|
|
from .utils import deprecated
|
|
|
|
try:
|
|
import cPickle as pickle
|
|
except ImportError:
|
|
import pickle
|
|
|
|
# Only the v2 helper is exported by ``from <module> import *``; the other
# define_py_data_source* functions are still importable by explicit name.
__all__ = ['define_py_data_sources2']
|
|
|
|
|
|
def define_py_data_source(file_list, cls, module,
|
|
obj, args=None, async=False,
|
|
data_cls=PyData):
|
|
"""
|
|
Define a python data source.
|
|
|
|
For example, the simplest usage in trainer_config.py as follow:
|
|
|
|
.. code-block:: python
|
|
|
|
define_py_data_source("train.list", TrainData, "data_provider", "process")
|
|
|
|
Or. if you want to pass arguments from trainer_config to data_provider.py, then
|
|
|
|
.. code-block:: python
|
|
|
|
define_py_data_source("train.list", TrainData, "data_provider", "process",
|
|
args={"dictionary": dict_name})
|
|
|
|
:param data_cls:
|
|
:param file_list: file list name, which contains all data file paths
|
|
:type file_list: basestring
|
|
:param cls: Train or Test Class.
|
|
:type cls: TrainData or TestData
|
|
:param module: python module name.
|
|
:type module: basestring
|
|
:param obj: python object name. May be a function name if using
|
|
PyDataProviderWrapper.
|
|
:type obj: basestring
|
|
:param args: The best practice is using dict to pass arguments into
|
|
DataProvider, and use :code:`@init_hook_wrapper` to
|
|
receive arguments.
|
|
:type args: string or picklable object
|
|
:param async: Load Data asynchronously or not.
|
|
:type async: bool
|
|
:return: None
|
|
:rtype: None
|
|
"""
|
|
if isinstance(file_list, list):
|
|
file_list_name = 'train.list'
|
|
if isinstance(cls, TestData):
|
|
file_list_name = 'test.list'
|
|
with open(file_list_name, 'r') as f:
|
|
f.writelines(file_list)
|
|
file_list = file_list_name
|
|
|
|
if not isinstance(args, basestring) and args is not None:
|
|
args = pickle.dumps(args, 0)
|
|
|
|
if data_cls is None:
|
|
def py_data2(files, load_data_module, load_data_object, load_data_args,
|
|
**kwargs):
|
|
data = DataBase()
|
|
data.type = 'py2'
|
|
data.files = files
|
|
data.load_data_module = load_data_module
|
|
data.load_data_object = load_data_object
|
|
data.load_data_args = load_data_args
|
|
return data
|
|
data_cls = py_data2
|
|
|
|
cls(data_cls(files=file_list,
|
|
load_data_module=module,
|
|
load_data_object=obj,
|
|
load_data_args=args,
|
|
async_load_data=async))
|
|
|
|
|
|
def define_py_data_sources(train_list, test_list, module, obj, args=None,
                           train_async=False, data_cls=PyData):
    """
    Same as define_py_data_sources2, except that it can also specify
    train_async and data_cls.

    :param train_list: Train list name.
    :type train_list: basestring
    :param test_list: Test list name.
    :type test_list: basestring
    :param module: python module name. If train and test are different, then
                   pass a (train, test) tuple or list to this argument.
    :type module: basestring or tuple or list
    :param obj: python object name. May be a function name if using
                PyDataProviderWrapper. If train and test are different, then
                pass a (train, test) tuple or list to this argument.
    :type obj: basestring or tuple or list
    :param args: The best practice is using dict() to pass arguments into
                 DataProvider, and use :code:`@init_hook_wrapper` to receive
                 arguments. If train and test are different, then pass a
                 (train, test) tuple or list to this argument.
    :type args: string or picklable object or list or tuple.
    :param train_async: Whether training data is loaded asynchronously.
                        Test data is always loaded synchronously.
    :type train_async: bool
    :param data_cls: data config factory forwarded to define_py_data_source.
    :return: None
    :rtype: None
    """

    def _split(value):
        # A 2-element list/tuple carries separate (train, test) values;
        # anything else is shared by both data sources.
        if isinstance(value, (list, tuple)) and len(value) == 2:
            train_value, test_value = value
            return train_value, test_value
        return value, value

    assert train_list is not None or test_list is not None, \
        "at least one of train_list / test_list must be given"
    assert module is not None and obj is not None, \
        "module and obj must not be None"

    train_module, test_module = _split(module)
    # obj is split on its own; previously a (train, test) obj pair was
    # ignored and the whole pair was passed to both data sources.
    train_obj, test_obj = _split(obj)

    if args is None:
        args = ""
    train_args, test_args = _split(args)

    if train_list is not None:
        define_py_data_source(train_list, TrainData, train_module, train_obj,
                              train_args, train_async, data_cls)

    if test_list is not None:
        define_py_data_source(test_list, TestData, test_module, test_obj,
                              test_args, False, data_cls)
|
|
|
|
|
|
def define_py_data_sources2(train_list, test_list, module, obj, args=None):
    """
    Define python Train/Test data sources in one method. If train/test use
    the same Data Provider configuration, module/obj/args each contain one
    argument; otherwise each contains a (train, test) list or tuple of
    arguments. For example:

    .. code-block:: python

        define_py_data_sources2(train_list="train.list",
                                test_list="test.list",
                                module="data_provider",
                                # if train/test use different configurations,
                                # obj=["process_train", "process_test"]
                                obj="process",
                                args={"dictionary": dict_name})

    For the related data provider, refer to
    `here <../../data_provider/pydataprovider2.html#dataprovider-for-the-sequential-model>`__.

    :param train_list: Train list name.
    :type train_list: basestring
    :param test_list: Test list name.
    :type test_list: basestring
    :param module: python module name. If train and test are different, then
                   pass a tuple or list to this argument.
    :type module: basestring or tuple or list
    :param obj: python object name. May be a function name if using
                PyDataProviderWrapper. If train and test are different, then
                pass a tuple or list to this argument.
    :type obj: basestring or tuple or list
    :param args: The best practice is using dict() to pass arguments into
                 DataProvider, and use :code:`@init_hook_wrapper` to receive
                 arguments. If train and test are different, then pass a
                 tuple or list to this argument.
    :type args: string or picklable object or list or tuple.
    :return: None
    :rtype: None
    """
    # data_cls=None selects the 'py2' DataBase factory inside
    # define_py_data_source; training data is loaded synchronously.
    define_py_data_sources(train_list=train_list,
                           test_list=test_list,
                           module=module,
                           obj=obj,
                           args=args,
                           data_cls=None)
|