remove ctr reader, all functions are satisfied in dataset (#18672)
* remove ctr reader, all functions are satisfied in datasetDDDivano-patch-1
parent
898237c19a
commit
5ed713d519
File diff suppressed because it is too large
Load Diff
@ -1,189 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <sys/time.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono> // NOLINT
|
||||
#include <cstdlib>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/fluid/framework/reader.h"
|
||||
#include "paddle/fluid/framework/threadpool.h"
|
||||
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
namespace reader {
|
||||
|
||||
enum ReaderThreadStatus { Running, Stopped };
|
||||
|
||||
struct DataDesc {
|
||||
DataDesc(int batch_size, const std::vector<std::string>& file_names,
|
||||
const std::string& file_type, const std::string& file_format,
|
||||
const std::vector<int>& dense_slot_index,
|
||||
const std::vector<int>& sparse_slot_index,
|
||||
const std::vector<std::string>& sparse_slot_ids)
|
||||
: batch_size_(batch_size),
|
||||
file_names_(file_names),
|
||||
file_type_(file_type),
|
||||
file_format_(file_format),
|
||||
dense_slot_index_(dense_slot_index),
|
||||
sparse_slot_index_(sparse_slot_index),
|
||||
sparse_slot_ids_(sparse_slot_ids) {}
|
||||
|
||||
const int batch_size_;
|
||||
const std::vector<std::string> file_names_;
|
||||
const std::string file_type_; // gzip or plain
|
||||
const std::string file_format_; // csv or svm
|
||||
// used for csv data format
|
||||
const std::vector<int> dense_slot_index_;
|
||||
const std::vector<int> sparse_slot_index_;
|
||||
// used for svm data format
|
||||
const std::vector<std::string> sparse_slot_ids_;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const DataDesc& data_desc) {
|
||||
os << "data_desc:\n";
|
||||
os << "\tbatch_size -> " << data_desc.batch_size_ << "\n";
|
||||
os << "\tfile_type -> " << data_desc.file_type_ << "\n";
|
||||
os << "\tfile_format -> " << data_desc.file_format_ << "\n";
|
||||
os << "\tfile_names -> {";
|
||||
for (auto& file_name : data_desc.file_names_) {
|
||||
os << file_name << ",";
|
||||
}
|
||||
os << "}\n";
|
||||
os << "\tdense_slot_index -> {";
|
||||
for (auto& slot : data_desc.dense_slot_index_) {
|
||||
os << slot << ",";
|
||||
}
|
||||
os << "}\n";
|
||||
os << "\tsparse_slot_index_ -> {";
|
||||
for (auto& slot : data_desc.sparse_slot_index_) {
|
||||
os << slot << ",";
|
||||
}
|
||||
os << "}\n";
|
||||
os << "\tsparse_slot_ids_ -> {";
|
||||
for (auto& slot : data_desc.sparse_slot_ids_) {
|
||||
os << slot << ",";
|
||||
}
|
||||
os << "}\n";
|
||||
|
||||
return os;
|
||||
}
|
||||
|
||||
void ReadThread(const std::vector<std::string>& file_list,
|
||||
const DataDesc& data_desc, int thread_id,
|
||||
std::vector<ReaderThreadStatus>* thread_status,
|
||||
std::shared_ptr<LoDTensorBlockingQueue> queue);
|
||||
|
||||
// monitor all running thread, if they are all stopped,
|
||||
// then push an empty data into LoDTensorBlockingQueue
|
||||
void MonitorThread(std::vector<ReaderThreadStatus>* thread_status,
|
||||
std::shared_ptr<LoDTensorBlockingQueue> queue);
|
||||
|
||||
class CTRReader : public framework::FileReader {
|
||||
public:
|
||||
CTRReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue,
|
||||
int thread_num, const DataDesc& data_desc)
|
||||
: data_desc_(data_desc) {
|
||||
PADDLE_ENFORCE_GT(thread_num, 0, "thread num should be larger then 0!");
|
||||
PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
|
||||
PADDLE_ENFORCE_GT(data_desc_.file_names_.size(), 0,
|
||||
"file list should not be empty");
|
||||
|
||||
thread_num_ = std::min<size_t>(data_desc_.file_names_.size(), thread_num);
|
||||
queue_ = queue;
|
||||
SplitFiles();
|
||||
for (size_t i = 0; i < thread_num_; ++i) {
|
||||
read_thread_status_.push_back(Stopped);
|
||||
}
|
||||
}
|
||||
|
||||
~CTRReader() { Shutdown(); }
|
||||
|
||||
void ReadNext(std::vector<framework::LoDTensor>* out) override {
|
||||
bool success;
|
||||
*out = queue_->Pop(&success);
|
||||
if (!success) out->clear();
|
||||
}
|
||||
|
||||
void Shutdown() override {
|
||||
VLOG(3) << "Shutdown reader";
|
||||
if (status_ == ReaderStatus::kStopped) {
|
||||
return;
|
||||
}
|
||||
// shutdown should stop all the reader thread
|
||||
for (auto& read_thread : read_threads_) {
|
||||
read_thread->join();
|
||||
}
|
||||
|
||||
if (monitor_thread_) {
|
||||
monitor_thread_->join();
|
||||
}
|
||||
|
||||
read_threads_.clear();
|
||||
monitor_thread_.reset(nullptr);
|
||||
queue_->Close();
|
||||
status_ = ReaderStatus::kStopped;
|
||||
}
|
||||
|
||||
void Start() override {
|
||||
VLOG(3) << "Start reader";
|
||||
PADDLE_ENFORCE_EQ(read_threads_.size(), 0, "read thread should be empty!");
|
||||
queue_->ReOpen();
|
||||
VLOG(3) << "reopen success";
|
||||
VLOG(3) << "thread_num " << thread_num_;
|
||||
for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) {
|
||||
read_threads_.emplace_back(new std::thread(std::bind(
|
||||
&ReadThread, file_groups_[thread_id], data_desc_,
|
||||
static_cast<int>(thread_id), &read_thread_status_, queue_)));
|
||||
}
|
||||
monitor_thread_.reset(new std::thread(
|
||||
std::bind(&MonitorThread, &read_thread_status_, queue_)));
|
||||
status_ = ReaderStatus::kRunning;
|
||||
}
|
||||
|
||||
private:
|
||||
void SplitFiles() {
|
||||
file_groups_.resize(thread_num_);
|
||||
for (size_t i = 0; i < data_desc_.file_names_.size(); ++i) {
|
||||
auto& file_name = data_desc_.file_names_[i];
|
||||
std::ifstream f(file_name.c_str());
|
||||
PADDLE_ENFORCE(f.good(), "file %s not exist!", file_name);
|
||||
file_groups_[i % thread_num_].push_back(file_name);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
size_t thread_num_;
|
||||
const DataDesc data_desc_;
|
||||
std::shared_ptr<LoDTensorBlockingQueue> queue_;
|
||||
std::vector<std::unique_ptr<std::thread>> read_threads_;
|
||||
std::unique_ptr<std::thread> monitor_thread_;
|
||||
std::vector<ReaderThreadStatus> read_thread_status_;
|
||||
std::vector<std::vector<std::string>> file_groups_;
|
||||
};
|
||||
|
||||
} // namespace reader
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -1,229 +0,0 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/operators/reader/ctr_reader.h"
|
||||
|
||||
#include <gzstream.h>
|
||||
#include <time.h>
|
||||
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/operators/reader/blocking_queue.h"
|
||||
|
||||
using paddle::operators::reader::LoDTensorBlockingQueue;
|
||||
using paddle::operators::reader::LoDTensorBlockingQueueHolder;
|
||||
using paddle::operators::reader::CTRReader;
|
||||
using paddle::framework::LoDTensor;
|
||||
using paddle::framework::LoD;
|
||||
using paddle::framework::DDim;
|
||||
using paddle::platform::CPUPlace;
|
||||
using paddle::framework::make_ddim;
|
||||
using paddle::operators::reader::DataDesc;
|
||||
|
||||
static void generatedata(const std::vector<std::string>& data,
|
||||
const std::string& file_name) {
|
||||
std::ifstream in(file_name.c_str());
|
||||
if (in.good()) {
|
||||
VLOG(3) << "file " << file_name << " exist, delete it first!";
|
||||
remove(file_name.c_str());
|
||||
} else {
|
||||
in.close();
|
||||
}
|
||||
|
||||
ogzstream out(file_name.c_str());
|
||||
PADDLE_ENFORCE(out.good(), "open file %s failed!", file_name);
|
||||
for (auto& c : data) {
|
||||
out << c;
|
||||
}
|
||||
out.close();
|
||||
PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name);
|
||||
}
|
||||
|
||||
static inline void check_all_data(
|
||||
const std::vector<std::string>& ctr_data,
|
||||
const std::vector<std::string>& slots, const std::vector<DDim>& label_dims,
|
||||
const std::vector<int64_t>& label_value,
|
||||
const std::vector<std::tuple<LoD, std::vector<int64_t>>>& data_slot_6002,
|
||||
const std::vector<std::tuple<LoD, std::vector<int64_t>>>& data_slot_6003,
|
||||
size_t batch_num, size_t batch_size,
|
||||
std::shared_ptr<LoDTensorBlockingQueue> queue, CTRReader* reader) {
|
||||
std::vector<LoDTensor> out;
|
||||
for (size_t i = 0; i < batch_num; ++i) {
|
||||
reader->ReadNext(&out);
|
||||
ASSERT_EQ(out.size(), slots.size() + 1);
|
||||
auto& label_tensor = out.back();
|
||||
ASSERT_EQ(label_tensor.dims(), label_dims[i]);
|
||||
for (size_t j = 0; j < batch_size && i * batch_num + j < ctr_data.size();
|
||||
++j) {
|
||||
auto& label = label_tensor.data<int64_t>()[j];
|
||||
ASSERT_TRUE(label == 0 || label == 1);
|
||||
ASSERT_EQ(label, label_value[i * batch_size + j]);
|
||||
}
|
||||
auto& tensor_6002 = out[0];
|
||||
ASSERT_EQ(std::get<0>(data_slot_6002[i]), tensor_6002.lod());
|
||||
ASSERT_EQ(std::memcmp(std::get<1>(data_slot_6002[i]).data(),
|
||||
tensor_6002.data<int64_t>(),
|
||||
tensor_6002.dims()[1] * sizeof(int64_t)),
|
||||
0);
|
||||
}
|
||||
reader->ReadNext(&out);
|
||||
ASSERT_EQ(out.size(), 0);
|
||||
ASSERT_EQ(queue->Size(), 0);
|
||||
}
|
||||
|
||||
TEST(CTR_READER, read_data) {
|
||||
const std::vector<std::string> ctr_data = {
|
||||
"0 0:6002 1:6003 2:6004 3:6005 4:6006 \n",
|
||||
"0 5:6003 6:6003 7:6003 8:6004 9:6004 \n",
|
||||
"1 10:6002 11:6002 12:6002 13:6002 14:6002 \n",
|
||||
"0 15:6003 16:6003 17:6003 18:6003 19:6004 \n",
|
||||
"1 20:6001 21:6001 22:6001 23:6001 24:6001 \n",
|
||||
"1 25:6004 26:6004 27:6004 28:6005 29:6005 \n",
|
||||
"0 30:6002 31:6003 32:6004 33:6004 34:6005 \n",
|
||||
"1 35:6003 36:6003 37:6005 38:6005 39:6005 \n",
|
||||
"1 40:6002 41:6003 42:6004 43:6004 44:6005 \n",
|
||||
"1 46:6006 45:6006 47:6003 48:6003 49:6003 \n",
|
||||
};
|
||||
std::string gz_file_name = "test_ctr_reader_data.gz";
|
||||
generatedata(ctr_data, gz_file_name);
|
||||
|
||||
std::vector<int64_t> label_value = {0, 0, 1, 0, 1, 1, 0, 1, 1, 1};
|
||||
|
||||
std::tuple<LoD, std::vector<int64_t>> a1({{0, 1, 2, 7}},
|
||||
{0, 0, 10, 11, 12, 13, 14});
|
||||
std::tuple<LoD, std::vector<int64_t>> a2({{0, 1, 2, 3}}, {0, 0, 0});
|
||||
std::tuple<LoD, std::vector<int64_t>> a3({{0, 1, 2, 3}}, {30, 0, 40});
|
||||
std::tuple<LoD, std::vector<int64_t>> a4({{0, 1}}, {0});
|
||||
std::vector<std::tuple<LoD, std::vector<int64_t>>> data_slot_6002{a1, a2, a3,
|
||||
a4};
|
||||
|
||||
std::tuple<LoD, std::vector<int64_t>> b1({{0, 1, 4, 5}}, {1, 5, 6, 7, 0});
|
||||
std::tuple<LoD, std::vector<int64_t>> b2({{0, 4, 5, 6}},
|
||||
{15, 16, 17, 18, 0, 0});
|
||||
std::tuple<LoD, std::vector<int64_t>> b3({{0, 1, 3, 4}}, {31, 35, 36, 41});
|
||||
std::tuple<LoD, std::vector<int64_t>> b4({{0, 3}}, {47, 48, 49});
|
||||
std::vector<std::tuple<LoD, std::vector<int64_t>>> data_slot_6003{b1, b2, b3,
|
||||
b4};
|
||||
|
||||
std::vector<DDim> label_dims = {{3, 1}, {3, 1}, {3, 1}, {1, 1}};
|
||||
|
||||
LoDTensorBlockingQueueHolder queue_holder;
|
||||
int capacity = 64;
|
||||
queue_holder.InitOnce(capacity, false);
|
||||
|
||||
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
|
||||
|
||||
int batch_size = 3;
|
||||
int thread_num = 1;
|
||||
std::vector<std::string> sparse_slots = {"6002", "6003"};
|
||||
std::vector<std::string> file_list;
|
||||
for (int i = 0; i < thread_num; ++i) {
|
||||
file_list.push_back(gz_file_name);
|
||||
}
|
||||
|
||||
DataDesc data_desc(batch_size, file_list, "gzip", "svm", {}, {},
|
||||
sparse_slots);
|
||||
|
||||
CTRReader reader(queue, thread_num, data_desc);
|
||||
|
||||
reader.Start();
|
||||
size_t batch_num =
|
||||
std::ceil(static_cast<float>(ctr_data.size()) / batch_size) * thread_num;
|
||||
check_all_data(ctr_data, sparse_slots, label_dims, label_value,
|
||||
data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
|
||||
&reader);
|
||||
|
||||
reader.Shutdown();
|
||||
|
||||
reader.Start();
|
||||
check_all_data(ctr_data, sparse_slots, label_dims, label_value,
|
||||
data_slot_6002, data_slot_6003, batch_num, batch_size, queue,
|
||||
&reader);
|
||||
reader.Shutdown();
|
||||
}
|
||||
|
||||
static void GenereteCsvData(const std::string& file_name,
|
||||
const std::vector<std::string>& data) {
|
||||
std::ofstream out(file_name.c_str());
|
||||
PADDLE_ENFORCE(out.good(), "open file %s failed!", file_name);
|
||||
for (auto& c : data) {
|
||||
out << c;
|
||||
}
|
||||
out.close();
|
||||
PADDLE_ENFORCE(out.good(), "save file %s failed!", file_name);
|
||||
}
|
||||
|
||||
static void CheckReadCsvOut(const std::vector<LoDTensor>& out) {
|
||||
ASSERT_EQ(out.size(), 3);
|
||||
ASSERT_EQ(out[0].dims()[1], 1);
|
||||
ASSERT_EQ(out[1].dims()[1], 2);
|
||||
ASSERT_EQ(out[2].dims()[1], 1);
|
||||
for (size_t i = 0; i < out[0].numel(); ++i) {
|
||||
int64_t label = out[0].data<int64_t>()[i];
|
||||
auto& dense_dim = out[1].dims();
|
||||
for (size_t j = 0; j < dense_dim[1]; ++j) {
|
||||
ASSERT_EQ(out[1].data<float>()[i * dense_dim[1] + j],
|
||||
static_cast<float>(label + 0.1));
|
||||
}
|
||||
auto& sparse_lod = out[2].lod();
|
||||
for (size_t j = sparse_lod[0][i]; j < sparse_lod[0][i + 1]; ++j) {
|
||||
ASSERT_EQ(out[2].data<int64_t>()[j], label);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CTR_READER, read_csv_data) {
|
||||
std::string file_name = "test_ctr_reader_data.csv";
|
||||
const std::vector<std::string> csv_data = {
|
||||
"0 0.1,0.1 0,0,0,0\n", "1 1.1,1.1 1,1,1,1\n", "2 2.1,2.1 2,2,2,2\n",
|
||||
"3 3.1,3.1 3,3,3,3\n",
|
||||
};
|
||||
GenereteCsvData(file_name, csv_data);
|
||||
|
||||
LoDTensorBlockingQueueHolder queue_holder;
|
||||
int capacity = 64;
|
||||
queue_holder.InitOnce(capacity, false);
|
||||
|
||||
std::shared_ptr<LoDTensorBlockingQueue> queue = queue_holder.GetQueue();
|
||||
|
||||
int batch_size = 3;
|
||||
int thread_num = 1;
|
||||
std::vector<std::string> file_list;
|
||||
for (int i = 0; i < thread_num; ++i) {
|
||||
file_list.push_back(file_name);
|
||||
}
|
||||
DataDesc data_desc(batch_size, file_list, "plain", "csv", {1}, {2}, {});
|
||||
|
||||
CTRReader reader(queue, thread_num, data_desc);
|
||||
|
||||
for (size_t i = 0; i < 2; ++i) {
|
||||
reader.Start();
|
||||
std::vector<LoDTensor> out;
|
||||
while (true) {
|
||||
reader.ReadNext(&out);
|
||||
if (out.empty()) {
|
||||
break;
|
||||
}
|
||||
CheckReadCsvOut(out);
|
||||
}
|
||||
reader.Shutdown();
|
||||
}
|
||||
}
|
@ -1,164 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle.fluid import core
|
||||
from paddle.fluid.executor import global_scope
|
||||
from paddle.fluid.framework import default_main_program, \
|
||||
default_startup_program, Variable
|
||||
from paddle.fluid.unique_name import generate as unique_name
|
||||
|
||||
__all__ = ['ctr_reader']
|
||||
|
||||
|
||||
def monkey_patch_reader_methods(reader):
|
||||
def __get_reader__():
|
||||
scope = global_scope()
|
||||
var = scope.find_var(reader.name)
|
||||
return var.get_reader()
|
||||
|
||||
def reset():
|
||||
return __get_reader__().reset()
|
||||
|
||||
def start():
|
||||
return __get_reader__().start()
|
||||
|
||||
reader.reset = reset
|
||||
reader.start = start
|
||||
reader.stop_gradient = True
|
||||
reader.persistable = True
|
||||
return reader
|
||||
|
||||
|
||||
def _copy_reader_var_(block, var):
|
||||
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
|
||||
new_var.desc.set_shapes(var.desc.shapes())
|
||||
new_var.desc.set_dtypes(var.desc.dtypes())
|
||||
new_var.persistable = True
|
||||
return new_var
|
||||
|
||||
|
||||
def ctr_reader(
|
||||
feed_dict,
|
||||
file_type, # gzip or plain
|
||||
file_format, # csv or svm
|
||||
dense_slot_index,
|
||||
sparse_slot_index,
|
||||
capacity,
|
||||
thread_num,
|
||||
batch_size,
|
||||
file_list,
|
||||
slots,
|
||||
name=None):
|
||||
"""
|
||||
Create a CTR reader for data feeding in Python
|
||||
|
||||
This layer returns a Reader Variable.
|
||||
The Reader provides :code:`decorate_paddle_reader()` and
|
||||
:code:`decorate_tensor_provider()` to set a Python generator as the data
|
||||
source in Python side. When :code:`Executor::Run()` is invoked in C++
|
||||
side, the data from the generator would be read automatically. Unlike
|
||||
:code:`DataFeeder.feed()`, the data reading process and
|
||||
:code:`Executor::Run()` process can run in parallel using
|
||||
:code:`py_reader`. The :code:`start()` method of the Reader should be
|
||||
called when each pass begins, while the :code:`reset()` method should be
|
||||
called when the pass ends and :code:`fluid.core.EOFException` raises.
|
||||
Note that :code:`Program.clone()` method cannot clone :code:`py_reader`.
|
||||
|
||||
Args:
|
||||
feed_dict(list(variable)): a list of data variable.
|
||||
file_type('gzip'|'plain'): the type of the data file
|
||||
file_format('csv'|'svm'): csv data or svm data format.
|
||||
cvs data format is :
|
||||
label dense_fea,dense_fea sparse_fea,sparse_fea
|
||||
the svm data format is :
|
||||
label slot1:fea_sign slot2:fea_sign slot1:fea_sign
|
||||
dense_slot_index(list(int)): the index of dense slots
|
||||
sparse_slot_index(list(int)): the index of sparse slots
|
||||
capacity(int): The buffer capacity maintained by :code:`py_reader`.
|
||||
thread_num(int): the thread num to read files by cpp reader.
|
||||
batch_size(int): batch size of data.
|
||||
file_list(list(str)): List of file names that need to read.
|
||||
slots(list(int64)): list of slot id.
|
||||
name(string): The prefix Python queue name and Reader name. None will
|
||||
be generated automatically.
|
||||
|
||||
Returns:
|
||||
Variable: A Reader from which we can get feeding data.
|
||||
|
||||
Examples:
|
||||
|
||||
1. The basic usage of :code:`ctr_reader` is as follows:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
py_reader = fluid.contrib.ctr_reader.ctr_reader(
|
||||
feed_dict=datas, file_type='plain', file_format='csv',
|
||||
file_list=file_list, dense_slot_indexs=[1, 2, 3, 4], sparse_slot_indexs=[],
|
||||
capacity=64, thread_num=20, batch_size=1000, slots=[], name='ctr_reader')
|
||||
|
||||
"""
|
||||
if name is None:
|
||||
queue_name = unique_name('lod_tensor_blocking_queue')
|
||||
reader_name = unique_name('create_ctr_reader')
|
||||
else:
|
||||
queue_name = "_".join([name, "queue"])
|
||||
reader_name = "_".join([name, "reader"])
|
||||
|
||||
var = global_scope().var(queue_name)
|
||||
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity)
|
||||
|
||||
startup_blk = default_startup_program().current_block()
|
||||
reader_var = startup_blk.create_var(name=reader_name)
|
||||
startup_blk.append_op(
|
||||
type='create_ctr_reader',
|
||||
inputs={'blocking_queue': [queue_name]},
|
||||
outputs={'Out': [reader_var]},
|
||||
attrs={
|
||||
'use_data_config': False,
|
||||
'thread_num': thread_num,
|
||||
'batch_size': batch_size,
|
||||
'file_list': file_list,
|
||||
'file_type': file_type,
|
||||
'file_format': file_format,
|
||||
'dense_slot_index': dense_slot_index,
|
||||
'sparse_slot_index': sparse_slot_index,
|
||||
'sparse_slots': slots,
|
||||
'ranks': [],
|
||||
'lod_levels': [],
|
||||
'shape_concat': []
|
||||
})
|
||||
|
||||
dtypes = [data.dtype for data in feed_dict]
|
||||
reader_var.desc.set_dtypes(dtypes)
|
||||
reader_var.persistable = True
|
||||
|
||||
main_prog_reader_var = _copy_reader_var_(
|
||||
default_main_program().current_block(), reader_var)
|
||||
|
||||
reader = monkey_patch_reader_methods(main_prog_reader_var)
|
||||
|
||||
# monkey patch py_reader special methods
|
||||
reader.queue = feed_queue
|
||||
reader.exited = False
|
||||
|
||||
main_blk = default_main_program().current_block()
|
||||
main_blk.append_op(
|
||||
type='read',
|
||||
inputs={'Reader': [reader]},
|
||||
attrs={'infer_out': False},
|
||||
outputs={'Out': feed_dict})
|
||||
|
||||
return reader
|
Loading…
Reference in new issue