You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/contrib/reader/ctr_reader.py

165 lines
5.7 KiB

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import core
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import default_main_program, \
default_startup_program, Variable
from paddle.fluid.unique_name import generate as unique_name
# Names exported by ``from <this module> import *``: only ``ctr_reader``.
__all__ = ['ctr_reader']
def monkey_patch_reader_methods(reader):
    """Attach ``reset``/``start`` helpers to ``reader`` and mark it persistable.

    The attached helpers look the reader up by ``reader.name`` in the global
    scope each time they are called and forward to the object returned by
    ``get_reader()`` on that scope variable.

    Args:
        reader: the reader variable to patch; it is modified in place.

    Returns:
        The same ``reader`` object, with ``reset`` and ``start`` attached and
        ``stop_gradient`` / ``persistable`` set to ``True``.
    """

    def _runtime_reader():
        # Resolved lazily on every call rather than once at patch time.
        return global_scope().find_var(reader.name).get_reader()

    def reset():
        underlying = _runtime_reader()
        return underlying.reset()

    def start():
        underlying = _runtime_reader()
        return underlying.start()

    reader.reset = reset
    reader.start = start
    reader.stop_gradient = True
    reader.persistable = True
    return reader
def _copy_reader_var_(block, var):
    """Create a persistable READER variable in ``block`` that mirrors ``var``.

    The new variable reuses ``var.name`` and receives the shapes and dtypes
    recorded on ``var.desc``.

    Args:
        block: the block in which the new variable is created.
        var: the reader variable whose name, shapes and dtypes are copied.

    Returns:
        The newly created variable.
    """
    reader_type = core.VarDesc.VarType.READER
    copied = block.create_var(name=var.name, type=reader_type)

    target_desc = copied.desc
    source_desc = var.desc
    target_desc.set_shapes(source_desc.shapes())
    target_desc.set_dtypes(source_desc.dtypes())

    copied.persistable = True
    return copied
def ctr_reader(
        feed_dict,
        file_type,  # gzip or plain
        file_format,  # csv or svm
        dense_slot_index,
        sparse_slot_index,
        capacity,
        thread_num,
        batch_size,
        file_list,
        slots,
        name=None):
    """
    Create a CTR reader for data feeding in Python

    This layer returns a Reader Variable.

    The function does the following:

    1. creates a queue variable in the global scope and initializes a
       LoD-tensor blocking queue of size :code:`capacity` on it;
    2. appends a :code:`create_ctr_reader` op to the current block of the
       default startup program, configured with the given file and slot
       settings;
    3. copies the reader variable into the current block of the default main
       program and appends a :code:`read` op there whose outputs are the
       variables in :code:`feed_dict`.

    The returned Reader has :code:`start()` and :code:`reset()` methods, plus
    the attributes :code:`queue` (the blocking queue created in step 1) and
    :code:`exited` (initialized to :code:`False`). The :code:`start()` method
    of the Reader should be called when each pass begins, while the
    :code:`reset()` method should be called when the pass ends and
    :code:`fluid.core.EOFException` raises.

    Note that :code:`Program.clone()` method cannot clone :code:`py_reader`.

    Args:
        feed_dict(list(variable)): a list of data variable. Their dtypes are
            recorded on the reader and they receive the output of the
            :code:`read` op.
        file_type('gzip'|'plain'): the type of the data file
        file_format('csv'|'svm'): csv data or svm data format.
            csv data format is :
                label dense_fea,dense_fea sparse_fea,sparse_fea
            the svm data format is :
                label slot1:fea_sign slot2:fea_sign slot1:fea_sign
        dense_slot_index(list(int)): the index of dense slots
        sparse_slot_index(list(int)): the index of sparse slots
        capacity(int): The capacity of the blocking queue created for this
            reader.
        thread_num(int): the thread num to read files by cpp reader.
        batch_size(int): batch size of data.
        file_list(list(str)): List of file names that need to read.
        slots(list(int64)): list of slot id.
        name(string): The prefix of the queue name and the Reader name
            (:code:`<name>_queue` and :code:`<name>_reader`). If None, both
            names are generated automatically.

    Returns:
        Variable: A Reader from which we can get feeding data.

    Examples:

        1. The basic usage of :code:`ctr_reader` is as follows:

        .. code-block:: python

            py_reader = fluid.contrib.ctr_reader.ctr_reader(
                feed_dict=datas, file_type='plain', file_format='csv',
                file_list=file_list, dense_slot_index=[1, 2, 3, 4],
                sparse_slot_index=[], capacity=64, thread_num=20,
                batch_size=1000, slots=[], name='ctr_reader')
    """
    if name is None:
        queue_name = unique_name('lod_tensor_blocking_queue')
        reader_name = unique_name('create_ctr_reader')
    else:
        queue_name = "_".join([name, "queue"])
        reader_name = "_".join([name, "reader"])

    # The blocking queue lives in the global scope under ``queue_name``; the
    # create_ctr_reader op below refers to it by that name.
    var = global_scope().var(queue_name)
    feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)

    startup_blk = default_startup_program().current_block()
    reader_var = startup_blk.create_var(name=reader_name)
    startup_blk.append_op(
        type='create_ctr_reader',
        inputs={'blocking_queue': [queue_name]},
        outputs={'Out': [reader_var]},
        attrs={
            'use_data_config': False,
            'thread_num': thread_num,
            'batch_size': batch_size,
            'file_list': file_list,
            'file_type': file_type,
            'file_format': file_format,
            'dense_slot_index': dense_slot_index,
            'sparse_slot_index': sparse_slot_index,
            'sparse_slots': slots,
            'ranks': [],
            'lod_levels': [],
            'shape_concat': []
        })

    dtypes = [data.dtype for data in feed_dict]
    reader_var.desc.set_dtypes(dtypes)
    reader_var.persistable = True

    # Fetch the main program's current block once; it hosts both the copied
    # reader variable and the read op.
    main_blk = default_main_program().current_block()
    main_prog_reader_var = _copy_reader_var_(main_blk, reader_var)

    reader = monkey_patch_reader_methods(main_prog_reader_var)

    # monkey patch py_reader special methods
    reader.queue = feed_queue
    reader.exited = False

    main_blk.append_op(
        type='read',
        inputs={'Reader': [reader]},
        attrs={'infer_out': False},
        outputs={'Out': feed_dict})

    return reader