# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides a wrapper (decorator) to wrap a data processing method
into a PyDataProvider. Some examples are shown
`here <data_provider/python_case.html>`_.
"""

import struct
import array
import random
import gc
import logging
import pstats
import sys
import numpy
import functools

__all__ = [
    'DenseSlot', 'SlotType', 'SparseNonValueSlot', 'StringSlot',
    'SparseValueSlot', 'IndexSlot', 'PoolSize', 'GeneralPyDataProvider',
    'provider', 'init_hook_wrapper'
]

try:  # Just for profile mode, will try to import cProfile first.
    # Most Python installations contain cProfile; cProfile and profile are
    # basically the same.
    # ref: https://docs.python.org/2/library/profile.html#introduction-to-the-profilers
    import cProfile as profile
except ImportError:
    import profile

try:
    import cPickle as pickle
except ImportError:
    import pickle

import io


class SlotType(object):  # Just a hint for user.
    pass


class DenseSlot(SlotType):
    """
    Dense Slot Type: Each item is the value of a Dense Vector.

    Its yield format for :code:`provider` is:

    - **NonSeq**: [float, float, ... ]
    - **Seq**: [[float, float, ...], [float, float ....], ... ]
    - **SubSeq**: [[[float, float, ...], [float ....], ...] , \
                  [[float, float, ...], [float ....], ...] , ...]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension
        :type dim: int
        """
        self.dim = dim
        self.type = 0


class SparseNonValueSlot(SlotType):
    """
    Sparse NonValue Slot Type: Each item is the id of a Sparse Vector.

    Its yield format for :code:`provider` is:

    - **NonSeq**: [int, int, ...]
    - **Seq**: [[int, int, ...], [int, int, ...], ... ]
    - **SubSeq**: [[[int, int, ...], [int, ....], ...] , \
                  [[int, int, ...], [int, ....], ...] , ...]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension
        :type dim: int
        """
        self.dim = dim
        self.type = 1


class SparseValueSlot(SlotType):
    """
    Sparse Value Slot Type: Each item is the id and value of a Sparse Vector.

    Its yield format for :code:`provider` is:

    - **NonSeq**: [(int, float), (int, float), ... ]
    - **Seq**: [[(int,float), (int, float), ... ], \
               [(int, float), (int, float), ...], ... ]
    - **SubSeq**: [[[(int,float), ...], [(int, float), ....], ...] , \
                  [[(int,float), ...], [(int, float), ....], ...] , ...]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension.
        :type dim: int
        """
        self.dim = dim
        self.type = 2


class IndexSlot(SlotType):
    """
    Index Value Slot Type: Each item is the id of Label.

    Its yield format for :code:`provider` is:

    - **NonSeq**: int
    - **Seq**: [int, int, ....]
    - **SubSeq**: [[int, int, ...], [int, int, ...], ... ]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension
        :type dim: int
        """
        self.dim = dim
        self.type = 3


class StringSlot(SlotType):
    """
    String Value Slot Type: Each item is a string for printout; \
    it can be used in a DataLayer too.

    Its yield format for :code:`provider` is:

    - **NonSeq**: string
    - **Seq**: [string, string, ....]
    - **SubSeq**: [[string, string, ...], [string, string, ...], ... ]
    """

    def __init__(self, dim):
        """
        :param dim: slot dimension
        :type dim: int
        """
        self.dim = dim
        self.type = 6


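# Illustrative sketch (an assumption for documentation purposes, not part of
# this module's API): with slots = [DenseSlot(3), IndexSlot(2)] and
# use_seq=False, a user's process generator would yield one sample at a time,
# one entry per slot, following the per-slot formats documented above, e.g.
#
#     yield [0.2, 0.5, 0.3], 1    # dense feature vector, then the label id

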
class SparseNonValueHandler(object):
    """
    Private class, used for converting a python object to a paddle string.
    """

    def __init__(self):
        self.offsets = []
        self.value = []
        self.offset_count = 0

    def __call__(self, ele):
        """
        It will be invoked when scanning each piece of sparse data.

        :param ele: list of sparse data, either non-value [ idx, ... ] or
                    value [ (idx, val), ... ]
        :type ele: list
        """
        self.offsets.append(self.offset_count)
        self.offset_count += len(ele)
        self.processElement(ele)

    def processElement(self, ele):
        """
        Process an element list. See __call__ for more documentation.
        """
        self.value += ele

    def done(self, data_stream, int_packer):
        """
        Dump data to stream.
        :param data_stream: Output Stream.
        :param int_packer: A struct.Struct("i") object
        """
        data_stream.write(array.array("i", self.offsets).tostring())
        data_stream.write(int_packer.pack(self.offset_count))
        data_stream.write(array.array("i", self.value).tostring())


class SparseValueHandler(SparseNonValueHandler):
    """
    Private class, used for converting a python object to a paddle string.
    """

    def __init__(self):
        SparseNonValueHandler.__init__(self)
        self.weight = []

    def processElement(self, ele):
        for idx, w in ele:
            self.value.append(idx)
            self.weight.append(w)

    def done(self, data_stream, int_packer):
        SparseNonValueHandler.done(self, data_stream, int_packer)
        data_stream.write(int_packer.pack(self.offset_count))
        data_stream.write(array.array("f", self.weight).tostring())


class StringHandler(object):
    """
    Private class, used for converting a python object to a paddle string.
    """

    def __init__(self, data_stream, int_packer):
        self.data_stream = data_stream
        self.int_packer = int_packer

    def __call__(self, ele):
        """
        It will be invoked when scanning each string datum.
        :param ele: string data
        :type ele: str
        """
        self.data_stream.write(self.int_packer.pack(len(ele)))
        self.data_stream.write(array.array("c", ele).tostring())


class GeneralPyDataProvider:
    def __init__(self, *file_list, **kwargs):
        """
        :param file_list: input file_list
        """
        del kwargs  # unused
        gc.disable()
        assert isinstance(self.logger, logging.Logger)
        self.use_seq_flag = hasattr(self, "use_seq_flag") and self.use_seq_flag
        self.slots_num = len(self.getSlots())
        self.file_list = list(file_list)
        self.generators = map(self.generateData, self.file_list)
        self.int_packer = struct.Struct("i")
        self.head_packer = struct.Struct("ii")
        self.float_packer = struct.Struct("f")
        self.shuffler = lambda *args, **kwargs: None
        self.data_pool = []
        self.has_subseq = []
        self.has_checked = False

        self.debug = hasattr(self, "debug") and self.debug

        if hasattr(self, "profile_filename") and isinstance(
                self.profile_filename, str):
            self.profile_count = 0
            self.is_profile = True
        else:
            self.is_profile = False

        if not hasattr(self, "file_count") or not isinstance(self.file_count,
                                                             int):
            self.file_count = sys.maxint

        if not hasattr(self, "can_over_batch_size"):
            self.can_over_batch_size = True
        elif not self.can_over_batch_size:
            self.logger.warn(
                "User should ensure every data size is not larger than batch"
                " size when can_over_batch_size = False")

        self.data_pool_idx = 0

    def reset(self):
        """Reset all data in provider."""

        self.logger.debug("reset dataprovider.")
        self.generators = map(self.generateData, self.file_list)
        self.shuffler = lambda *args, **kwargs: None
        self.data_pool = []
        self.data_pool_idx = 0
        if self.file_count != 0:
            self.max_pool_size = 0

        # When use Profile, each pass will print a profile result.
        if self.is_profile:
            if hasattr(self, "profiler") and isinstance(self.profiler,
                                                        profile.Profile):
                self.profiler.disable()
                fn = "%s_%d" % (self.profile_filename, self.profile_count)
                sortby = "cumulative"
                with open(fn, "w") as f:
                    pstats.Stats(
                        self.profiler,
                        stream=f).sort_stats(sortby).print_stats()
                self.logger.info("saving profile to file %s" % fn)
                self.profile_count += 1
            self.logger.info("resetting profile")
            self.profiler = profile.Profile()
            self.profiler.enable()

    def shuffle(self):
        """ shuffle data"""
        if not self.should_shuffle:
            return
        else:
            self.logger.debug("shuffling data.")
            random.shuffle(self.generators)
            self.shuffler = random.shuffle

    def getSlots(self):
        """
        :return: a list of SlotType
        :rtype: list
        """
        return []

    def generateData(self, fn):
        """
        :param fn: file name
        :return: a generator to yield data one by one.
        """
        raise NotImplementedError

    def calculateDataBatchSize(self, data):
        """
        :param data: One sample yielded by generateData
        :type data: list
        :return: The batch size that this sample contributes.
        :rtype: int
        """
        return 1

    def getHeader(self):
        """return paddle header format"""
        ret = self.head_packer.pack(self.slots_num, self.use_seq_flag)
        for obj in self.getSlots():
            ret += self.head_packer.pack(obj.type, obj.dim)
        return ret

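    # For example (values are illustrative, derived from the packing above):
    # with getSlots() == [DenseSlot(3), IndexSlot(2)] and use_seq_flag == 0,
    # getHeader() returns the bytes of
    #     struct.pack("ii", 2, 0)    # slots_num, use_seq_flag
    #   + struct.pack("ii", 0, 3)    # DenseSlot: type 0, dim 3
    #   + struct.pack("ii", 3, 2)    # IndexSlot: type 3, dim 2
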
    def getHeaderNative(self):
        return self.use_seq_flag, self.getSlots()

    def getNextBatchNative(self, batch_size):
        ret_list = []
        self.__prepareData(batch_size, ret_list)
        return ret_list

    def getNextBatch(self, batch_size):
        """
        :param batch_size: the batch size to return, approximately.
        :return: data in paddle pyDataProvider binary format; see the documents.
        :rtype: str

        NOTE: If can_over_batch_size is True, the returned batch size may be
        >= the input batch_size. Otherwise, the returned batch size is < the
        input batch_size, BUT USERS MUST ENSURE THAT each sample's batch size
        is less than the input batch_size.
        """
        ret_list = []
        current_batch_size = self.__prepareData(batch_size, ret_list)
        # create unified format for ret_list with different slots_num
        if self.slots_num == 1:
            ret_list = [ret_list]

        if current_batch_size == 0:
            return self.int_packer.pack(current_batch_size)
        data_bytes = io.BytesIO()
        seq_bytes = io.BytesIO()
        subseq_bytes = io.BytesIO()
        data_stream = io.BufferedWriter(data_bytes)
        seq_stream = io.BufferedWriter(seq_bytes)
        subseq_stream = io.BufferedWriter(subseq_bytes)

        def convertDataImpl(idx, data_callback):
            """
            This method handles sequence information in the returned data and
            invokes data_callback one sample at a time.
            :param idx: the slot index.
            :param data_callback: a callback, whose type is (each sample) => None.
            """
            indices = 0
            slot_sample_num = len(ret_list)
            if self.use_seq_flag:
                slot_sample_num = 0
                if self.has_subseq[idx]:  # has sub-sequence
                    slot_subseq_num = 0
                    for dat in ret_list:
                        dat = dat[idx]
                        slot_subseq_num += len(dat)
                        for sub_dat in dat:
                            slot_sample_num += len(sub_dat)
                    subseq_stream.write(self.int_packer.pack(slot_subseq_num))
                else:
                    for dat in ret_list:
                        dat = dat[idx]
                        slot_sample_num += len(dat)
                seq_stream.write(self.int_packer.pack(len(ret_list)))
            data_stream.write(self.int_packer.pack(slot_sample_num))

            for dat in ret_list:
                dat = dat[idx]
                if self.use_seq_flag:
                    seq_stream.write(self.int_packer.pack(indices))
                    if self.has_subseq[idx]:  # has sub-sequence
                        for sub_dat in dat:
                            writeDataStream(sub_dat, data_callback)
                            subseq_stream.write(self.int_packer.pack(indices))
                            indices += len(sub_dat)
                    else:
                        writeDataStream(dat, data_callback)
                        indices += len(dat)
                else:
                    writeDataStream(dat, data_callback)

        def writeDataStream(dat, data_callback):
            if self.use_seq_flag > 0:
                if data_callback is None:  # Special for index slot
                    data_stream.write(array.array("i", dat).tostring())
                else:
                    for ele in dat:
                        data_callback(ele)
            else:
                if data_callback is None:  # Special for index slot
                    data_stream.write(self.int_packer.pack(dat))
                else:
                    data_callback(dat)

        try:
            for i in range(self.slots_num):
                slot = self.getSlots()[i]
                # According to the data_type, each slot's data will be converted to binary
                if isinstance(slot, DenseSlot):
                    convertDataImpl(i, lambda e: data_stream.write(
                        array.array("f", e).tostring()))
                elif isinstance(slot, SparseNonValueSlot):
                    handler = SparseNonValueHandler()
                    convertDataImpl(i, handler)
                    handler.done(data_stream, self.int_packer)
                elif isinstance(slot, SparseValueSlot):
                    handler = SparseValueHandler()
                    convertDataImpl(i, handler)
                    handler.done(data_stream, self.int_packer)
                elif isinstance(slot, IndexSlot):
                    convertDataImpl(i, None)
                elif isinstance(slot, StringSlot):
                    handler = StringHandler(data_stream, self.int_packer)
                    convertDataImpl(i, handler)
                else:
                    raise RuntimeError("The data_type must be 0/1/2/3/6")
            data_stream.flush()
            seq_stream.flush()
            subseq_stream.flush()

            return "".join([
                self.int_packer.pack(current_batch_size), data_bytes.getvalue(),
                seq_bytes.getvalue(), subseq_bytes.getvalue()
            ])

        finally:
            data_stream.close()
            seq_stream.close()
            subseq_stream.close()
            data_bytes.close()
            seq_bytes.close()
            subseq_bytes.close()

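    # Layout of the string assembled by getNextBatch() above (a summary of the
    # code for reference, not an addition to the format): one int32 batch
    # size, then each slot's data section in slot order, then the
    # sequence-offset sections, then the sub-sequence-offset sections (the
    # latter two are written only when use_seq_flag / has_subseq apply).
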
    def hasSubseq(self, ret_list):
        # create unified format for ret_list with different slots_num
        if self.slots_num == 1:
            ret_list = [ret_list]
        # decide whether a slot has sub-sequences using its first sample
        for i in range(self.slots_num):
            slot = self.getSlots()[i]
            dat = ret_list[0][i][0]
            if isinstance(slot, IndexSlot) or isinstance(slot, StringSlot):
                if isinstance(dat, list) or isinstance(dat, numpy.ndarray):
                    self.has_subseq.append(1)  # has_subseq = True
                    continue
            elif isinstance(dat[0], list) or isinstance(dat[0], numpy.ndarray):
                self.has_subseq.append(1)  # has_subseq = True
                continue
            self.has_subseq.append(0)  # has_subseq = False

    def checkOrder(self):
        first_noSubseq_slot = self.slots_num
        last_subseq_slot = -1
        for i in range(self.slots_num):
            if not self.has_subseq[i]:
                first_noSubseq_slot = i
                break
        for i in range(self.slots_num):
            if self.has_subseq[i]:
                last_subseq_slot = i
        if first_noSubseq_slot < last_subseq_slot:
            raise RuntimeError(
                "slots with sub-sequence must be put before slots without"
                " sub-sequence")
        self.has_checked = True

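    # For example, has_subseq == [1, 1, 0] passes checkOrder(), while
    # has_subseq == [0, 1] raises RuntimeError: every slot that carries
    # sub-sequences has to be placed before all slots that do not.
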
    def __prepareData(self, batch_size, ret_list):
        current_batch_size = 0
        could_exit = False
        while not could_exit:
            if len(self.data_pool) == 0:
                self.data_pool_idx = 0
                self.fillPool()
            if len(self.data_pool) != 0:
                for idx in xrange(self.data_pool_idx, len(self.data_pool)):
                    current_batch_size += self.calculateDataBatchSize(
                        self.data_pool[idx])
                    if current_batch_size >= batch_size:
                        could_exit = True
                        break
                if current_batch_size > batch_size and not self.can_over_batch_size:  # if cannot over batch size
                    current_batch_size -= self.calculateDataBatchSize(
                        self.data_pool[idx])
                    idx -= 1

                ret_list += self.data_pool[self.data_pool_idx:idx + 1]

                # for speed reasons, just shift the left index; do not actually delete data.
                self.data_pool_idx = idx + 1

                if self.data_pool_idx == len(self.data_pool):
                    self.data_pool = []
            else:
                break
        if self.use_seq_flag and not self.has_checked:  # compute self.has_subseq and checkOrder only the first time
            self.hasSubseq(ret_list)
            self.checkOrder()
        return current_batch_size

    def fillPool(self):
        """
        Fill the pool up to max_pool_size. If max_pool_size is 0, read
        file_count files into the pool instead.
        """
        if self.max_pool_size == 0:
            for i in xrange(min(self.file_count, len(self.generators))):
                self.data_pool += list(self.generators[i])
            self.generators = self.generators[min(self.file_count,
                                                  len(self.generators)):]
            self.max_pool_size = len(self.data_pool)
        else:
            while len(self.data_pool) < self.max_pool_size and len(
                    self.generators) != 0:
                try:
                    self.data_pool.append(self.generators[0].next())
                except StopIteration:
                    self.generators.pop(0)
        self.shuffler(self.data_pool)


class PoolSize(object):
    """Max number of samples held in the provider's data pool."""

    def __init__(self, pool_size):
        self.size = pool_size


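# Usage note (illustrative): passing pool_size=10 to @provider keeps at most
# 10 files' worth of data in memory per fill, while pool_size=PoolSize(65536)
# caps the pool at 65536 samples instead.

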
def default_init_hook(cls, *args, **kwargs):
    """ default hook, do nothing """
    del cls, args, kwargs


def provider(slots=None,
             use_seq=False,
             should_shuffle=True,
             pool_size=1,
             can_over_batch_size=True,
             calc_batch_size=lambda data: 1,
             debug=False,
             init_hook=default_init_hook,
             profile_filename=None):
    """
    The decorator for PyDataProvider. Users should use this to create a
    Provider class, and only need to concern themselves with how to read
    samples from a file.

    So the basic usage is:

    .. code-block:: python

        @provider(some data provider config here...)
        def process(obj, file_name):
            while not at end of file_name:
                sample = readOneSampleFromFile(file_name)
                yield sample.

    The configuration of the data provider should be set up by:

    :param init_hook: A callback invoked when the PyDataProvider instance is \
                      created. The parameter is (obj, \*args, \*\*kwargs).

                      - **obj**: the actual data provider instance, which \
                        contains some global objects in obj.xxxxx and is used \
                        by the process function.

                        1. **obj.slots**: a list of SlotType objects. Can be \
                           set in init. For example, obj.slots = \
                           [DenseSlot(9), IndexSlot(2)].
                        2. **obj.logger**: a logger object. Users can invoke \
                           obj.logger.info(), obj.logger.fatal(), etc.

                      - **args** and **kwargs**: the data provider __init__ \
                        parameters. For example, load_data_args will be found \
                        in \*\*kwargs, and if you want to receive it from \
                        trainer_config, it is recommended to use \
                        init_hook_wrapper.
    :type init_hook: callable

    :param pool_size:
        - **int**: it will read at most pool_size files to memory.
        - **PoolSize**: it will read at most PoolSize.size samples to memory.
        - If not set, it will read all the files to memory.
    :type pool_size: int | PoolSize

    :param slots: Specify the SlotTypes; can also be set in init_hook. It has two formats:

                  - A list of SlotType objects. For example, slots = \
                    [DenseSlot(9), IndexSlot(2)].
                  - A method returning a list of SlotTypes, whose parameters \
                    are (obj, \*file_list, \*\*kwargs).
    :type slots: list | callable

    :param use_seq: False if no sequence is used (Default). True if sequence is used:

                    - If the sequence has **no sub-sequence**: Each slot will \
                      return a list of data. This list is one sequence. \
                      So the return format looks like \
                      [[a0, a1, a2], [b1, b2, b3, b4], [c1]].
                    - If the sequence has **sub-sequence**: Each slot will return \
                      a nested list of data. This list contains several \
                      sub-lists, each sub-list is one sub-sequence. \
                      So the return format looks like \
                      [[[a0, a1, a2], [a4, a5]], [[b1, b2, b3, b4], [b5, b6]], [[c1], [c2]]].
    :type use_seq: bool

    :param should_shuffle: True if data should be shuffled.
    :type should_shuffle: bool

    :param calc_batch_size: The method that calculates each sample's batch size.

                            - Default is the batch size of one sample.
                            - Users can customize it with a **lambda** function. For example, \
                              :code:`calc_batch_size = lambda data : len(data)` \
                              means calculating the token number of a sequence sample.
    :type calc_batch_size: callable

    :param can_over_batch_size: Whether :code:`actual batch size >= input batch size`

                                - **True** (>=): getNextBatch method can return more data (Default).
                                - **False** (<): users must ensure that each sample's batch size < input batch size.
    :type can_over_batch_size: bool

    :param debug: True to enable the debug logger and some debug checks. Default is False.
    :type debug: bool

    :param profile_filename: None to disable profiling (Default). Otherwise, \
                             the data provider will dump a profile result on \
                             reset. The dump filename is \
                             **<profile_filename>_<reset_count>**.
    :type profile_filename: None | str
    """

    def _wrapper(handler):
        class Cls(GeneralPyDataProvider):
            """ Real PyDataProvider Class. """

            def __init__(self, *file_list, **kwargs):
                logging.basicConfig(
                    format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
                    " %(message)s")

                self.logger = logging.getLogger("")
                if debug:
                    self.logger.setLevel(logging.DEBUG)
                    self.logger.debug("Running pydataprovider in debug mode.")
                else:
                    self.logger.setLevel(logging.INFO)

                init_hook(self, *file_list, **kwargs)
                if callable(slots):
                    self.slots = slots(self, *file_list, **kwargs)
                elif slots is not None:
                    self.slots = slots

                if isinstance(pool_size, int):
                    self.max_pool_size = 0
                    self.file_count = pool_size
                elif isinstance(pool_size, PoolSize):
                    self.max_pool_size = pool_size.size
                    self.file_count = 0
                else:
                    raise RuntimeError
                self.can_over_batch_size = can_over_batch_size
                self.debug = debug
                self.profile_filename = profile_filename
                self.use_seq_flag = use_seq
                self.should_shuffle = should_shuffle
                GeneralPyDataProvider.__init__(self, *file_list, **kwargs)

            def getSlots(self):
                return self.slots

            def generateData(self, f):
                return handler(self, f)

            def calculateDataBatchSize(self, data):
                return calc_batch_size(data)

        return Cls

    return _wrapper


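# A concrete sketch of the usage pattern described in the provider docstring
# above (the file format and parse_line helper are assumptions, not part of
# this module):
#
#     @provider(slots=[SparseNonValueSlot(10000), IndexSlot(2)],
#               pool_size=PoolSize(65536))
#     def process(obj, file_name):
#         with open(file_name, "r") as f:
#             for line in f:
#                 word_ids, label = parse_line(line)  # user-defined parsing
#                 yield word_ids, label

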
def init_hook_wrapper(func):
    """
    Wrap a method for PyDataProviderWrapper's init_hook. This method can
    receive parameters from trainer_config's load_data_args. The load_data_args
    must pass a pickle.dumps() value that dumps a map as keyword args. The
    wrapped method :code:`func` will receive them as keyword args.

    So an example usage is:

    .. code-block:: python

        @init_hook_wrapper
        def hook(obj, dictionary, file_list, **kwargs):
            obj.dictionary = dictionary
            obj.slots = [IndexSlot(len(obj.dictionary)),
                         IndexSlot(len(open(file_list[0], "r").readlines()))]

    :param func: init_hook function
    :type func: callable
    :return: wrapped method, can be passed into @provider.
    """

    @functools.wraps(func)
    def wrapper(obj, *file_list, **kwargs):
        args = kwargs.get("load_data_args", dict())
        if isinstance(args, basestring):
            args = pickle.loads(args)
        args['file_list'] = file_list
        func(obj=obj, **args)

    return wrapper
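
# Illustrative sketch (an assumption about the calling side, not defined in
# this module): the trainer config passes load_data_args as a pickled dict,
# which wrapper() above unpickles into keyword arguments for the hook, e.g.
#
#     load_data_args = pickle.dumps({"dictionary": word_dict})
#
# so the wrapped hook is finally called as
# hook(obj, dictionary=word_dict, file_list=(...)).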