# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provides a wrapper (decorator) that turns a data processing method
into a PyDataProvider. Some examples are shown `here <data_provider/python_case.html>`_.
"""
import struct
import array
import random
import gc
import logging
import pstats
import sys
import numpy
import functools
__all__ = [
'DenseSlot', 'SlotType', 'SparseNonValueSlot', 'StringSlot',
'SparseValueSlot', 'IndexSlot', 'PoolSize', 'GeneralPyDataProvider',
'provider', 'init_hook_wrapper'
]
try:  # Just for profile mode; try to import cProfile first.
    # Most Python installations include cProfile; cProfile and profile are
    # basically the same.
    # ref: https://docs.python.org/2/library/profile.html#introduction-to-the-profilers
import cProfile as profile
except ImportError:
import profile
try:
import cPickle as pickle
except ImportError:
import pickle
import io
class SlotType(object):  # Just a hint for the user.
pass
class DenseSlot(SlotType):
"""
Dense Slot Type: Each item is the value of a Dense Vector.
Its yield format for :code:`provider` is:
- **NonSeq**: [float, float, ... ]
- **Seq**: [[float, float, ...], [float, float ....], ... ]
- **SubSeq**: [[[float, float, ...], [float ....], ...] , \
[[float, float, ...], [float ....], ...] , ...]
"""
def __init__(self, dim):
"""
:param dim: slot dimension
:type dim: int
"""
self.dim = dim
self.type = 0
class SparseNonValueSlot(SlotType):
"""
Sparse NonValue Slot Type: Each item is the id of a Sparse Vector.
Its yield format for :code:`provider` is:
- **NonSeq**: [int, int, ...]
- **Seq**: [[int, int, ...], [int, int, ...], ... ]
- **SubSeq**: [[[int, int, ...], [int, ....], ...] , \
[[int, int, ...], [int, ....], ...] , ...]
"""
def __init__(self, dim):
"""
:param dim: slot dimension
:type dim: int
"""
self.dim = dim
self.type = 1
class SparseValueSlot(SlotType):
"""
Sparse Value Slot Type: Each item is the id and value of a Sparse Vector.
Its yield format for :code:`provider` is:
- **NonSeq**: [(int, float), (int, float), ... ]
- **Seq**: [[(int,float), (int, float), ... ], \
[(int, float), (int, float), ...], ... ]
- **SubSeq**: [[[(int,float), ...], [(int, float), ....], ...] , \
[[(int,float), ...], [(int, float), ....], ...] , ...]
"""
def __init__(self, dim):
"""
:param dim: slot dimension.
:type dim: int
"""
self.dim = dim
self.type = 2
class IndexSlot(SlotType):
"""
    Index Value Slot Type: Each item is the id of a label.
Its yield format for :code:`provider` is:
- **NonSeq**: int
- **Seq**: [int, int, ....]
- **SubSeq**: [[int, int, ...], [int, int, ...], ... ]
"""
def __init__(self, dim):
"""
:param dim: slot dimension
:type dim: int
"""
self.dim = dim
self.type = 3
class StringSlot(SlotType):
"""
    String Value Slot Type: Each item is a string for printing; \
    it can also be used in a DataLayer.
Its yield format for :code:`provider` is:
- **NonSeq**: string
- **Seq**: [string, string, ....]
- **SubSeq**: [[string, string, ...], [string, string, ...], ... ]
"""
def __init__(self, dim):
"""
:param dim: slot dimension
        :type dim: int
"""
self.dim = dim
self.type = 6
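# A quick sketch of how slot declarations pair with yielded samples (an
# illustrative, non-sequence example; the values are hypothetical):
#
#     obj.slots = [DenseSlot(3), IndexSlot(2)]   # set in init_hook
#
#     def process(obj, file_name):
#         # one value per slot, in slot order: a 3-d dense vector, then a label
#         yield [0.5, -1.2, 3.0], 1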
class SparseNonValueHandler(object):
"""
    Private class, used to convert Python objects into Paddle's string format.
"""
def __init__(self):
self.offsets = []
self.value = []
self.offset_count = 0
def __call__(self, ele):
"""
        Invoked when scanning each sparse datum.
        :param ele: list of sparse data; either non-value [idx, ...] or
                    value [(idx, val), ...]
        :type ele: list
"""
self.offsets.append(self.offset_count)
self.offset_count += len(ele)
self.processElement(ele)
def processElement(self, ele):
"""
        Process an element list. See __call__ for details.
"""
self.value += ele
def done(self, data_stream, int_packer):
"""
Dump data to stream.
:param data_stream: Output Stream.
:param int_packer: A struct.Struct("i") object
"""
data_stream.write(array.array("i", self.offsets).tostring())
data_stream.write(int_packer.pack(self.offset_count))
data_stream.write(array.array("i", self.value).tostring())
class SparseValueHandler(SparseNonValueHandler):
"""
    Private class, used to convert Python objects into Paddle's string format.
"""
def __init__(self):
SparseNonValueHandler.__init__(self)
self.weight = []
def processElement(self, ele):
for idx, w in ele:
self.value.append(idx)
self.weight.append(w)
def done(self, data_stream, int_packer):
SparseNonValueHandler.done(self, data_stream, int_packer)
data_stream.write(int_packer.pack(self.offset_count))
data_stream.write(array.array("f", self.weight).tostring())
class StringHandler(object):
"""
    Private class, used to convert Python objects into Paddle's string format.
"""
def __init__(self, data_stream, int_packer):
self.data_stream = data_stream
self.int_packer = int_packer
def __call__(self, ele):
"""
        Invoked when scanning each string datum.
:param ele: string data
:type ele: str
"""
self.data_stream.write(self.int_packer.pack(len(ele)))
self.data_stream.write(array.array("c", ele).tostring())
class GeneralPyDataProvider:
def __init__(self, *file_list, **kwargs):
"""
:param file_list: input file_list
"""
del kwargs # unused
gc.disable()
assert isinstance(self.logger, logging.Logger)
self.use_seq_flag = hasattr(self, "use_seq_flag") and self.use_seq_flag
self.slots_num = len(self.getSlots())
self.file_list = list(file_list)
self.generators = map(self.generateData, self.file_list)
self.int_packer = struct.Struct("i")
self.head_packer = struct.Struct("ii")
self.float_packer = struct.Struct("f")
self.shuffler = lambda *args, **kwargs: None
self.data_pool = []
self.has_subseq = []
self.has_checked = False
self.debug = hasattr(self, "debug") and self.debug
if hasattr(self, "profile_filename") and isinstance(
self.profile_filename, str):
self.profile_count = 0
self.is_profile = True
else:
self.is_profile = False
if not hasattr(self, "file_count") or not isinstance(self.file_count,
int):
self.file_count = sys.maxint
if not hasattr(self, "can_over_batch_size"):
self.can_over_batch_size = True
elif not self.can_over_batch_size:
self.logger.warn(
"User should ensure every data size is not larger than batch"
" size when can_over_batch_size = False")
self.data_pool_idx = 0
def reset(self):
"""Reset all data in provider."""
self.logger.debug("reset dataprovider.")
self.generators = map(self.generateData, self.file_list)
self.shuffler = lambda *args, **kwargs: None
self.data_pool = []
self.data_pool_idx = 0
if self.file_count != 0:
self.max_pool_size = 0
# When use Profile, each pass will print a profile result.
if self.is_profile:
if hasattr(self, "profiler") and isinstance(self.profiler,
profile.Profile):
self.profiler.disable()
fn = "%s_%d" % (self.profile_filename, self.profile_count)
sortby = "cumulative"
with open(fn, "w") as f:
pstats.Stats(self.profiler, stream=f).sort_stats(
sortby).print_stats()
self.logger.info("saving profile to file %s" % fn)
self.profile_count += 1
self.logger.info("resetting profile")
self.profiler = profile.Profile()
self.profiler.enable()
def shuffle(self):
""" shuffle data"""
if not self.should_shuffle:
return
else:
self.logger.debug("shuffling data.")
random.shuffle(self.generators)
self.shuffler = random.shuffle
def getSlots(self):
"""
        :return: a list of SlotType
:rtype: list
"""
return []
def generateData(self, fn):
"""
:param fn: file name
:return: a generator to yield data one by one.
"""
raise NotImplementedError
def calculateDataBatchSize(self, data):
"""
        :param data: One sample yielded by generateData.
        :type data: list
        :return: The batch size that this sample contributes.
:rtype: int
"""
return 1
def getHeader(self):
"""return paddle header format"""
ret = self.head_packer.pack(self.slots_num, self.use_seq_flag)
for obj in self.getSlots():
ret += self.head_packer.pack(obj.type, obj.dim)
return ret
def getHeaderNative(self):
return self.use_seq_flag, self.getSlots()
def getNextBatchNative(self, batch_size):
ret_list = []
self.__prepareData(batch_size, ret_list)
return ret_list
def getNextBatch(self, batch_size):
"""
        :param batch_size: the approximate batch size to return.
        :return: data in Paddle's PyDataProvider binary format; see the documentation.
        :rtype: str
        NOTE: If can_over_batch_size is True, the returned batch size may be >= the
        input batch size. Otherwise the returned batch size is <= the input batch
        size, BUT THE USER MUST ENSURE that each sample's batch size is smaller
        than the input batch size.
"""
ret_list = []
current_batch_size = self.__prepareData(batch_size, ret_list)
        # create unified format for ret_list with different slots_num
if self.slots_num == 1:
ret_list = [ret_list]
if current_batch_size == 0:
return self.int_packer.pack(current_batch_size)
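        # The serialized batch assembled below is, roughly:
        #   [batch_size][per-slot data][sequence offsets][sub-sequence offsets]
        # where the offset sections are only written when use_seq_flag is set.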
data_bytes = io.BytesIO()
seq_bytes = io.BytesIO()
subseq_bytes = io.BytesIO()
data_stream = io.BufferedWriter(data_bytes)
seq_stream = io.BufferedWriter(seq_bytes)
subseq_stream = io.BufferedWriter(subseq_bytes)
def convertDataImpl(idx, data_callback):
"""
            Handle sequence information in the returned data, invoking
            data_callback sample by sample.
            :param idx: the slot index.
            :param data_callback: a callback of type (sample) => None.
"""
indices = 0
slot_sample_num = len(ret_list)
if self.use_seq_flag:
slot_sample_num = 0
if self.has_subseq[idx]: # has sub-sequence
slot_subseq_num = 0
for dat in ret_list:
dat = dat[idx]
slot_subseq_num += len(dat)
for sub_dat in dat:
slot_sample_num += len(sub_dat)
subseq_stream.write(self.int_packer.pack(slot_subseq_num))
else:
for dat in ret_list:
dat = dat[idx]
slot_sample_num += len(dat)
seq_stream.write(self.int_packer.pack(len(ret_list)))
data_stream.write(self.int_packer.pack(slot_sample_num))
for dat in ret_list:
dat = dat[idx]
if self.use_seq_flag:
seq_stream.write(self.int_packer.pack(indices))
if self.has_subseq[idx]: # has sub-sequence
for sub_dat in dat:
writeDataStream(sub_dat, data_callback)
subseq_stream.write(self.int_packer.pack(indices))
indices += len(sub_dat)
else:
writeDataStream(dat, data_callback)
indices += len(dat)
else:
writeDataStream(dat, data_callback)
def writeDataStream(dat, data_callback):
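            # For sequence data, `dat` is a list of samples: index slots
            # (data_callback is None) are bulk-packed as an int array, while
            # other slots invoke the callback once per element. For
            # non-sequence data, `dat` is a single sample.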
if self.use_seq_flag > 0:
if data_callback is None: # Special for index slot
data_stream.write(array.array("i", dat).tostring())
else:
for ele in dat:
data_callback(ele)
else:
if data_callback is None: # Special for index slot
data_stream.write(self.int_packer.pack(dat))
else:
data_callback(dat)
try:
for i in range(self.slots_num):
slot = self.getSlots()[i]
# According to the data_type, each slot data will be converted to binary
if isinstance(slot, DenseSlot):
convertDataImpl(i, lambda e: data_stream.write(
array.array("f", e).tostring()))
elif isinstance(slot, SparseNonValueSlot):
handler = SparseNonValueHandler()
convertDataImpl(i, handler)
handler.done(data_stream, self.int_packer)
elif isinstance(slot, SparseValueSlot):
handler = SparseValueHandler()
convertDataImpl(i, handler)
handler.done(data_stream, self.int_packer)
elif isinstance(slot, IndexSlot):
convertDataImpl(i, None)
elif isinstance(slot, StringSlot):
handler = StringHandler(data_stream, self.int_packer)
convertDataImpl(i, handler)
else:
raise RuntimeError("The data_type must be 0/1/2/3/6")
data_stream.flush()
seq_stream.flush()
subseq_stream.flush()
return "".join([self.int_packer.pack(current_batch_size),
data_bytes.getvalue(),
seq_bytes.getvalue(), subseq_bytes.getvalue()])
finally:
data_stream.close()
seq_stream.close()
subseq_stream.close()
data_bytes.close()
seq_bytes.close()
subseq_bytes.close()
def hasSubseq(self, ret_list):
        # create unified format for ret_list with different slots_num
if self.slots_num == 1:
ret_list = [ret_list]
# decide whether slot has sub-sequence using its first sample
for i in range(self.slots_num):
slot = self.getSlots()[i]
dat = ret_list[0][i][0]
if isinstance(slot, IndexSlot) or isinstance(slot, StringSlot):
if isinstance(dat, list) or isinstance(dat, numpy.ndarray):
self.has_subseq.append(1) # has_subseq = True
continue
elif isinstance(dat[0], list) or isinstance(dat[0], numpy.ndarray):
self.has_subseq.append(1) # has_subseq = True
continue
self.has_subseq.append(0) # has_subseq = False
def checkOrder(self):
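        # Paddle requires every slot that has sub-sequences to appear before
        # all slots without them: e.g. has_subseq == [1, 1, 0] is accepted,
        # while [0, 1] raises the RuntimeError below.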
first_noSubseq_slot = self.slots_num
last_subseq_slot = -1
for i in range(self.slots_num):
if not self.has_subseq[i]:
first_noSubseq_slot = i
break
for i in range(self.slots_num):
if self.has_subseq[i]:
last_subseq_slot = i
if first_noSubseq_slot < last_subseq_slot:
            raise RuntimeError(
                "slots with sub-sequences must be placed before slots "
                "without sub-sequences")
self.has_checked = True
def __prepareData(self, batch_size, ret_list):
current_batch_size = 0
could_exit = False
while not could_exit:
if len(self.data_pool) == 0:
self.data_pool_idx = 0
self.fillPool()
if len(self.data_pool) != 0:
for idx in xrange(self.data_pool_idx, len(self.data_pool)):
current_batch_size += self.calculateDataBatchSize(
self.data_pool[idx])
if current_batch_size >= batch_size:
could_exit = True
break
if current_batch_size > batch_size and not self.can_over_batch_size: # if cannot over batch size
current_batch_size -= self.calculateDataBatchSize(
self.data_pool[idx])
idx -= 1
ret_list += self.data_pool[self.data_pool_idx: idx + 1]
        # For speed, just shift the left index instead of actually deleting data.
self.data_pool_idx = idx + 1
if self.data_pool_idx == len(self.data_pool):
self.data_pool = []
else:
break
        if self.use_seq_flag and not self.has_checked:  # compute self.has_subseq and check order only the first time
self.hasSubseq(ret_list)
self.checkOrder()
return current_batch_size
def fillPool(self):
"""
        Fill the pool up to max_pool_size samples. If max_pool_size is 0, read
        up to file_count whole files into the pool instead.
"""
if self.max_pool_size == 0:
for i in xrange(min(self.file_count, len(self.generators))):
self.data_pool += list(self.generators[i])
self.generators = self.generators[
min(self.file_count, len(self.generators)):]
self.max_pool_size = len(self.data_pool)
else:
while len(self.data_pool) < self.max_pool_size and len(
self.generators) != 0:
try:
self.data_pool.append(self.generators[0].next())
except StopIteration:
self.generators.pop(0)
self.shuffler(self.data_pool)
class PoolSize(object):
"""Max number of sample which contains in provider."""
def __init__(self, pool_size):
self.size = pool_size
def default_init_hook(cls, *args, **kwargs):
""" default hook, do nothing """
del cls, args, kwargs
def provider(slots=None, use_seq=False, should_shuffle=True, pool_size=1,
can_over_batch_size=True, calc_batch_size=lambda data: 1,
debug=False, init_hook=default_init_hook, profile_filename=None):
"""
    The decorator for PyDataProvider. Users should use this to create a Provider
    class, and only need to care about how to read samples from a file.
    So the basic usage is:
.. code-block:: python
@provider(some data provider config here...)
def process(obj, file_name):
while not at end of file_name:
sample = readOneSampleFromFile(file_name)
                yield sample
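    For instance, a provider that reads one 3-dimensional dense vector and an
    integer label from each line might look like this (a sketch; a file format
    of whitespace-separated floats followed by the label is assumed):
    .. code-block:: python
        @provider(slots=[DenseSlot(3), IndexSlot(2)])
        def process(obj, file_name):
            with open(file_name) as f:
                for line in f:
                    fields = line.split()
                    yield [float(x) for x in fields[:3]], int(fields[3])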
    The data provider is configured with the following parameters:
    :param init_hook: A callback invoked when the PyDataProvider instance is \
                      created. Its parameters are (obj, \*args, \*\*kwargs).
                      - **obj**: the actual data provider instance, which \
                                 carries global objects as obj.xxxxx \
                                 attributes used by the process function.
1. **obj.slots**: a list of SlotType Object. Can be \
set in init. For example, obj.slots = \
[DenseSlot(9), IndexSlot(2)].
2. **obj.logger**: a logger object. User can invoke \
obj.logger.info(), obj.logger.fatal(), etc.
                      - **args** and **kwargs**: the data provider __init__ \
                                                 parameters. For example, \
                                                 load_data_args will be found \
                                                 in \*\*kwargs; to receive it \
                                                 from trainer_config, it is \
                                                 recommended to use \
                                                 init_hook_wrapper.
:type init_hook: callable
:param pool_size:
- **int**: it will read at most pool_size files to memory.
- **PoolSize**: it will read at most PoolSize.size samples to memory.
        - If not set, pool_size defaults to 1, so one file at a time is read into memory.
:type pool_size: int | PoolSize
:param slots: Specify the SlotTypes, can also be set in init_hook. It has two formats:
- A list of SlotType objects. For example, slots = \
[DenseSlot(9), IndexSlot(2)].
        - A callable returning a list of SlotTypes, whose parameters \
          are (obj, \*file_list, \*\*kwargs).
:type slots: list | callable
:param use_seq: False if use no sequence (Default). True if use sequence:
- If sequence has **no sub-sequence**: Each slot will \
return a list of data. This list is one sequence. \
                      So the return format looks like \
[[a0, a1, a2], [b1, b2, b3, b4], [c1]].
- If sequence has **sub-sequence**: Each slot will return \
a nested-list of data. This list contains several \
sub-lists, each sub-list is one sub-sequence. \
                      So the return format looks like \
[[[a0, a1, a2], [a4, a5]], [[b1, b2, b3, b4], [b5, b6]], [[c1], [c2]]].
:type use_seq: bool
:param should_shuffle: True if data should shuffle.
:type should_shuffle: bool
    :param calc_batch_size: The method that calculates each sample's batch size.
                            - By default, each sample counts as a batch size of one.
                            - The user can customize it with a **lambda** \
                              function. For example, \
                              :code:`calc_batch_size = lambda data: len(data)` \
                              means the batch size of a sequence is its token count.
:type calc_batch_size: callable
:param can_over_batch_size: Whether :code:`actual batch size >= input batch size`
- **True** (>=): getNextBatch method can return more data (Default).
- **False** (<): user must ensure that each data's batch size < input batch size.
:type can_over_batch_size: bool
:param debug: True if enable debug logger and some debug check. Default is False.
:type debug: bool
    :param profile_filename: None to disable profiling (default). Otherwise, \
                             the data provider dumps a profiling result on \
                             each reset, to a file named \
                             **<profile_filename>_<reset_count>**.
    :type profile_filename: None | str
"""
def _wrapper(handler):
class Cls(GeneralPyDataProvider):
""" Real PyDataProvider Class. """
def __init__(self, *file_list, **kwargs):
logging.basicConfig(
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")
self.logger = logging.getLogger("")
if debug:
self.logger.setLevel(logging.DEBUG)
self.logger.debug("Running pydataprovider in debug mode.")
else:
self.logger.setLevel(logging.INFO)
init_hook(self, *file_list, **kwargs)
if callable(slots):
self.slots = slots(self, *file_list, **kwargs)
elif slots is not None:
self.slots = slots
if isinstance(pool_size, int):
self.max_pool_size = 0
self.file_count = pool_size
elif isinstance(pool_size, PoolSize):
self.max_pool_size = pool_size.size
self.file_count = 0
else:
raise RuntimeError
self.can_over_batch_size = can_over_batch_size
self.debug = debug
self.profile_filename = profile_filename
self.use_seq_flag = use_seq
self.should_shuffle = should_shuffle
GeneralPyDataProvider.__init__(self, *file_list, **kwargs)
def getSlots(self):
return self.slots
def generateData(self, f):
return handler(self, f)
def calculateDataBatchSize(self, data):
return calc_batch_size(data)
return Cls
return _wrapper
def init_hook_wrapper(func):
"""
    Wrap a method to serve as a PyDataProviderWrapper init_hook. The wrapped
    method can receive parameters from trainer_config's load_data_args, which
    must be a pickle.dumps() value encoding a dict of keyword args. The
    wrapped method :code:`func` will receive them as keyword args.
So an example usage is:
.. code-block:: python
@init_hook_wrapper
def hook(obj, dictionary, file_list, **kwargs):
obj.dictionary = dictionary
obj.slots = [IndexSlot(len(obj.dictionary)),
IndexSlot(len(open(file_list[0], "r").readlines()))]
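    On the trainer_config side, the matching load_data_args value could be
    produced like this (a sketch; word_dict is an illustrative variable):
    .. code-block:: python
        load_data_args = pickle.dumps({"dictionary": word_dict})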
:param func: init_hook function
:type func: callable
    :return: the wrapped method, which can be passed into @provider.
"""
@functools.wraps(func)
def wrapper(obj, *file_list, **kwargs):
args = kwargs.get("load_data_args", dict())
if isinstance(args, basestring):
args = pickle.loads(args)
args['file_list'] = file_list
func(obj=obj, **args)
return wrapper