You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/data_feed.cc

1183 lines
37 KiB

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// _LINUX is defined on every platform that is neither Windows nor macOS.
// The pipe/mmap based readers further down are compiled only under this macro.
#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif
#include "paddle/fluid/framework/data_feed.h"
#ifdef _LINUX
#include <stdio_ext.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#endif
#include <utility>
#include "gflags/gflags.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "io/fs.h"
#include "io/shell.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/platform/timer.h"
namespace paddle {
namespace framework {
void RecordCandidateList::ReSize(size_t length) {
_mutex.lock();
_capacity = length;
CHECK(_capacity > 0); // NOLINT
_candidate_list.clear();
_candidate_list.resize(_capacity);
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
}
void RecordCandidateList::ReInit() {
_mutex.lock();
_full = false;
_cur_size = 0;
_total_size = 0;
_mutex.unlock();
}
void RecordCandidateList::AddAndGet(const Record& record,
RecordCandidate* result) {
_mutex.lock();
size_t index = 0;
++_total_size;
auto fleet_ptr = FleetWrapper::GetInstance();
if (!_full) {
_candidate_list[_cur_size++] = record;
_full = (_cur_size == _capacity);
} else {
CHECK(_cur_size == _capacity);
index = fleet_ptr->LocalRandomEngine()() % _total_size;
if (index < _capacity) {
_candidate_list[index] = record;
}
}
index = fleet_ptr->LocalRandomEngine()() % _cur_size;
*result = _candidate_list[index];
_mutex.unlock();
}
// Binds `var` to every used slot whose name equals `name`. Passing a null
// variable clears the binding for those slots instead.
void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
  CheckInit();
  const size_t slot_count = use_slots_.size();
  for (size_t slot = 0; slot < slot_count; ++slot) {
    if (!(name == use_slots_[slot])) {
      continue;
    }
    if (var != nullptr) {
      feed_vec_[slot] = var->GetMutable<LoDTensor>();
    } else {
      feed_vec_[slot] = nullptr;
    }
  }
}
bool DataFeed::SetFileList(const std::vector<std::string>& files) {
std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
CheckInit();
// Do not set finish_set_filelist_ flag,
// since a user may set file many times after init reader
filelist_.assign(files.begin(), files.end());
finish_set_filelist_ = true;
return true;
}
// Stores the default number of instances per batch.
// Fails through PADDLE_ENFORCE when `batch_size` is not positive.
void DataFeed::SetBatchSize(int batch_size) {
PADDLE_ENFORCE(batch_size > 0, "Illegal batch size: %d.", batch_size);
default_batch_size_ = batch_size;
}
// Hands out the next unread file name under the shared pick-file mutex.
//
// On success stores the name in `*filename`, advances the shared index and
// returns true. Returns false once every file has been handed out.
// Both the mutex and the index pointer must have been installed beforehand.
bool DataFeed::PickOneFile(std::string* filename) {
  PADDLE_ENFORCE(mutex_for_pick_file_ != nullptr,
                 "should call SetFileListMutex before PickOneFile");
  PADDLE_ENFORCE(file_idx_ != nullptr,
                 "should call SetFileListIndex before PickOneFile");
  std::unique_lock<std::mutex> lock(*mutex_for_pick_file_);
  // `>=` rather than `==`: the file list can be replaced by a shorter one
  // after the index has already advanced, and indexing past the end must
  // never happen.
  if (*file_idx_ >= filelist_.size()) {
    VLOG(3) << "DataFeed::PickOneFile no more file to pick";
    return false;
  }
  VLOG(3) << "file_idx_=" << *file_idx_;
  *filename = filelist_[(*file_idx_)++];
  return true;
}
// Fails through PADDLE_ENFORCE unless finish_init_ is set.
void DataFeed::CheckInit() {
PADDLE_ENFORCE(finish_init_, "Initialization did not succeed.");
}
// Fails through PADDLE_ENFORCE unless finish_set_filelist_ is set.
void DataFeed::CheckSetFileList() {
PADDLE_ENFORCE(finish_set_filelist_, "Set filelist did not succeed.");
}
// Fails through PADDLE_ENFORCE unless finish_start_ is set.
void DataFeed::CheckStart() {
PADDLE_ENFORCE(finish_start_, "Datafeed has not started running yet.");
}
void DataFeed::AssignFeedVar(const Scope& scope) {
CheckInit();
for (size_t i = 0; i < use_slots_.size(); ++i) {
feed_vec_[i] = scope.FindVar(use_slots_[i])->GetMutable<LoDTensor>();
}
}
// Copies `size` bytes from host memory `src` into the feed tensor buffer
// `dst`. Uses memcpy when the feed's place is a CPU place and a host-to-device
// cudaMemcpy otherwise; throws when built without CUDA and the place is not
// a CPU place.
void DataFeed::CopyToFeedTensor(void* dst, const void* src, size_t size) {
  if (platform::is_cpu_place(this->place_)) {
    // memcpy requires valid pointers even for a zero-length copy, and callers
    // may pass the data pointer of an empty vector.
    if (size > 0) {
      memcpy(dst, src, size);
    }
  } else {
#ifdef PADDLE_WITH_CUDA
    // A failed copy would otherwise feed uninitialised memory to the model.
    const cudaError_t status =
        cudaMemcpy(dst, src, size, cudaMemcpyHostToDevice);
    PADDLE_ENFORCE(status == cudaSuccess,
                   "cudaMemcpy to feed tensor failed: %s",
                   cudaGetErrorString(status));
#else
    PADDLE_THROW("Not supported GPU, Please compile WITH_GPU option");
#endif
  }
}
// Records the requested queue size and creates a fresh channel with that
// capacity. `queue_size` must be positive.
template <typename T>
void PrivateQueueDataFeed<T>::SetQueueSize(int queue_size) {
  PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size);
  this->queue_size_ = queue_size;
  this->queue_ = paddle::framework::MakeChannel<T>();
  this->queue_->SetCapacity(queue_size);
}
// Launches the background reader thread (detached) and marks the feed as
// started. Requires the file list to have been set. Always returns true.
template <typename T>
bool PrivateQueueDataFeed<T>::Start() {
  CheckSetFileList();
  // The reader runs ReadThread() on this object and is never joined.
  std::thread reader(&PrivateQueueDataFeed::ReadThread, this);
  read_thread_ = std::move(reader);
  read_thread_.detach();
  finish_start_ = true;
  return true;
}
// Reader loop: picks files one by one, parses every instance from the opened
// stream and pushes it into the queue. Closes the queue once no file is left.
// Compiled to a no-op outside _LINUX.
template <typename T>
void PrivateQueueDataFeed<T>::ReadThread() {
#ifdef _LINUX
  std::string filename;
  while (PickOneFile(&filename)) {
    int err_no = 0;
    fp_ = fs_open_read(filename, &err_no, pipe_command_);
    // Same guard the other readers in this file use: do not dereference a
    // stream that failed to open.
    CHECK(fp_ != nullptr);
    __fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
    T instance;
    while (ParseOneInstanceFromPipe(&instance)) {
      queue_->Put(instance);
    }
  }
  queue_->Close();
#endif
}
// Assembles the next batch from the queue: pulls up to default_batch_size_
// instances, stops early when the queue yields nothing more, and publishes a
// non-empty batch to the feed tensors.
// Returns the number of instances in the batch (always 0 outside _LINUX).
template <typename T>
int PrivateQueueDataFeed<T>::Next() {
#ifdef _LINUX
  CheckStart();
  T batch;
  int fetched = 0;
  for (; fetched < default_batch_size_; ++fetched) {
    T instance;
    if (!queue_->Get(instance)) {
      break;
    }
    AddInstanceToInsVec(&batch, instance, fetched);
  }
  batch_size_ = fetched;
  if (batch_size_ != 0) {
    PutToFeedVec(batch);
  }
  return batch_size_;
#else
  return 0;
#endif
}
// Explicit instantiation: the member templates above are defined in this
// translation unit, so the specialization used with MultiSlotType vectors is
// instantiated here.
template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
// Default-constructs the feed with no channels, no open file, no shared
// file-picking state, a single thread (id 0) and both optional parsers off.
template <typename T>
InMemoryDataFeed<T>::InMemoryDataFeed() {
  // Shared file-picking state and the current stream.
  this->mutex_for_pick_file_ = nullptr;
  this->file_idx_ = nullptr;
  this->fp_ = nullptr;
  // Threading configuration.
  this->thread_num_ = 1;
  this->thread_id_ = 0;
  // Optional leading fields of each record.
  this->parse_content_ = false;
  this->parse_ins_id_ = false;
  // Channels are attached later through the Set*Channel methods.
  this->consume_channel_ = nullptr;
  this->output_channel_ = nullptr;
  this->input_channel_ = nullptr;
}
// Marks the feed as started. Under _LINUX, when the output channel is empty
// and the input channel is not, everything readable from the input channel is
// first moved into the output channel. Always returns true.
template <typename T>
bool InMemoryDataFeed<T>::Start() {
#ifdef _LINUX
  this->CheckSetFileList();
  // Both channels start out null in the constructor; fail loudly (as Next()
  // does) instead of dereferencing a null pointer.
  CHECK(output_channel_ != nullptr);
  CHECK(input_channel_ != nullptr);
  if (output_channel_->Size() == 0 && input_channel_->Size() != 0) {
    std::vector<T> data;
    input_channel_->Read(data);
    output_channel_->Write(std::move(data));
  }
#endif
  this->finish_start_ = true;
  return true;
}
// Builds the next batch from the output channel: takes up to
// default_batch_size_ instances, forwards each one to the consume channel,
// and publishes a non-empty batch to the feed tensors.
// Returns the number of instances in the batch (always 0 outside _LINUX).
template <typename T>
int InMemoryDataFeed<T>::Next() {
#ifdef _LINUX
  this->CheckStart();
  CHECK(output_channel_ != nullptr);
  CHECK(consume_channel_ != nullptr);
  VLOG(3) << "output_channel_ size=" << output_channel_->Size()
          << ", consume_channel_ size=" << consume_channel_->Size()
          << ", thread_id=" << thread_id_;
  int index = 0;
  T instance;
  std::vector<T> ins_vec;
  ins_vec.reserve(this->default_batch_size_);
  while (index < this->default_batch_size_) {
    if (output_channel_->Size() == 0) {
      break;
    }
    // The size check and the read are two separate steps; when the read
    // yields nothing, stop instead of re-using the previous instance.
    if (!output_channel_->Get(instance)) {
      break;
    }
    ins_vec.push_back(instance);
    ++index;
    consume_channel_->Put(std::move(instance));
  }
  this->batch_size_ = index;
  VLOG(3) << "batch_size_=" << this->batch_size_
          << ", thread_id=" << thread_id_;
  if (this->batch_size_ != 0) {
    PutToFeedVec(ins_vec);
  } else {
    VLOG(3) << "finish reading, output_channel_ size="
            << output_channel_->Size()
            << ", consume_channel_ size=" << consume_channel_->Size()
            << ", thread_id=" << thread_id_;
  }
  return this->batch_size_;
#else
  return 0;
#endif
}
// Attaches the input channel. `channel` is reinterpreted as a
// ChannelObject<T>*; the caller is responsible for passing that type.
template <typename T>
void InMemoryDataFeed<T>::SetInputChannel(void* channel) {
  using ChannelType = paddle::framework::ChannelObject<T>;
  this->input_channel_ = static_cast<ChannelType*>(channel);
}
// Attaches the output channel. `channel` is reinterpreted as a
// ChannelObject<T>*; the caller is responsible for passing that type.
template <typename T>
void InMemoryDataFeed<T>::SetOutputChannel(void* channel) {
  using ChannelType = paddle::framework::ChannelObject<T>;
  this->output_channel_ = static_cast<ChannelType*>(channel);
}
// Attaches the consume channel. `channel` is reinterpreted as a
// ChannelObject<T>*; the caller is responsible for passing that type.
template <typename T>
void InMemoryDataFeed<T>::SetConsumeChannel(void* channel) {
  using ChannelType = paddle::framework::ChannelObject<T>;
  this->consume_channel_ = static_cast<ChannelType*>(channel);
}
// Stores the id of the thread this feed belongs to (used in log output).
template <typename T>
void InMemoryDataFeed<T>::SetThreadId(int thread_id) {
  this->thread_id_ = thread_id;
}
// Stores the configured number of threads.
template <typename T>
void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
  this->thread_num_ = thread_num;
}
// Enables or disables parsing of the leading content field of each record.
template <typename T>
void InMemoryDataFeed<T>::SetParseContent(bool parse_content) {
  this->parse_content_ = parse_content;
}
// Enables or disables parsing of the leading instance-id field of each record.
template <typename T>
void InMemoryDataFeed<T>::SetParseInsId(bool parse_ins_id) {
  this->parse_ins_id_ = parse_ins_id;
}
// Reads every file this thread manages to pick and writes all parsed
// instances into the input channel, logging the time spent per file.
// Compiled to a no-op outside _LINUX.
template <typename T>
void InMemoryDataFeed<T>::LoadIntoMemory() {
#ifdef _LINUX
  VLOG(3) << "LoadIntoMemory() begin, thread_id=" << thread_id_;
  for (std::string filename; this->PickOneFile(&filename);) {
    VLOG(3) << "PickOneFile, filename=" << filename
            << ", thread_id=" << thread_id_;

    // Open the file (optionally through the configured pipe command).
    int err_no = 0;
    this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
    CHECK(this->fp_ != nullptr);
    __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);

    paddle::framework::ChannelWriter<T> writer(input_channel_);
    T instance;
    platform::Timer timeline;
    timeline.Start();
    // Each parsed instance is moved into the writer; the moved-from object is
    // replaced by a fresh one before the next parse.
    for (; ParseOneInstanceFromPipe(&instance); instance = T()) {
      writer << std::move(instance);
    }
    writer.Flush();
    timeline.Pause();

    VLOG(3) << "LoadIntoMemory() read all lines, file=" << filename
            << ", cost time=" << timeline.ElapsedSec()
            << " seconds, thread_id=" << thread_id_;
  }
  VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
#endif
}
// Explicit instantiation: the member templates above are defined in this
// translation unit, so the Record specialization is instantiated here.
template class InMemoryDataFeed<Record>;
// Configures the feed from `data_feed_desc`: batch size, queue size, the
// per-slot name/type tables, the list of used slots with their shapes, and
// the pipe command. The descriptor must carry a multi-slot description.
//
// Tables sized by all_slot_num are indexed by the slot's position in the
// descriptor; use_slots_, use_slots_is_dense_ and use_slots_shape_ are
// indexed by the slot's position among the *used* slots.
void MultiSlotDataFeed::Init(
    const paddle::framework::DataFeedDesc& data_feed_desc) {
  finish_init_ = false;
  finish_set_filelist_ = false;
  finish_start_ = false;
  PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
                 "Multi_slot_desc has not been set.");
  // Bind by reference: the descriptor outlives this call, no copy needed.
  const paddle::framework::MultiSlotDesc& multi_slot_desc =
      data_feed_desc.multi_slot_desc();
  SetBatchSize(data_feed_desc.batch_size());
  // temporarily set queue size = batch size * 100
  SetQueueSize(data_feed_desc.batch_size() * 100);
  const size_t all_slot_num = multi_slot_desc.slots_size();
  all_slots_.resize(all_slot_num);
  all_slots_type_.resize(all_slot_num);
  use_slots_index_.resize(all_slot_num);
  total_dims_without_inductive_.resize(all_slot_num);
  inductive_shape_index_.resize(all_slot_num);
  use_slots_.clear();
  use_slots_is_dense_.clear();
  // Must be cleared together with the two lists above; otherwise a second
  // Init() appends new shapes after the stale ones and the used-slot
  // indices no longer line up.
  use_slots_shape_.clear();
  for (size_t i = 0; i < all_slot_num; ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
    total_dims_without_inductive_[i] = 1;
    inductive_shape_index_[i] = -1;
    if (!slot.is_used()) {
      continue;
    }
    use_slots_.push_back(all_slots_[i]);
    use_slots_is_dense_.push_back(slot.is_dense());
    const int shape_rank = slot.shape_size();
    std::vector<int> local_shape;
    local_shape.reserve(shape_rank);
    for (int j = 0; j < shape_rank; ++j) {
      const auto dim = slot.shape(j);
      local_shape.push_back(dim);
      if (slot.is_dense()) {
        // Positive dims contribute to the fixed element count; a -1 dim is
        // the one inferred from the data at feed time.
        if (dim > 0) {
          total_dims_without_inductive_[i] *= dim;
        }
        if (dim == -1) {
          inductive_shape_index_[i] = j;
        }
      }
    }
    use_slots_shape_.push_back(local_shape);
  }
  feed_vec_.resize(use_slots_.size());
  pipe_command_ = data_feed_desc.pipe_command();
  finish_init_ = true;
}
// Reader loop: for every file it manages to pick, parses all instances from
// the opened stream into the queue and logs the per-file instance count.
// Closes the queue at the end. Compiled to a no-op outside _LINUX.
void MultiSlotDataFeed::ReadThread() {
#ifdef _LINUX
  for (std::string filename; PickOneFile(&filename);) {
    int err_no = 0;
    fp_ = fs_open_read(filename, &err_no, pipe_command_);
    CHECK(fp_ != nullptr);
    __fsetlocking(&*fp_, FSETLOCKING_BYCALLER);

    std::vector<MultiSlotType> instance;
    int ins_num = 0;
    for (; ParseOneInstanceFromPipe(&instance); ++ins_num) {
      queue_->Put(instance);
    }
    VLOG(3) << "filename: " << filename << " inst num: " << ins_num;
  }
  queue_->Close();
#endif
}
// Validates that every line of `filename` follows the multi-slot text format:
// for each declared slot a positive id count followed by that many values of
// the slot's type ("float" or "uint64"), with nothing but whitespace after
// the last slot.
//
// Returns false (after logging the reason and the offending line) when the
// file cannot be opened or a line is malformed; returns true otherwise, and
// unconditionally outside _LINUX.
bool MultiSlotDataFeed::CheckFile(const char* filename) {
#ifdef _LINUX
  CheckInit();  // get info of slots
  std::ifstream fin(filename);
  if (!fin.good()) {
    VLOG(1) << "error: open file<" << filename << "> fail";
    return false;
  }
  std::string line;
  int instance_cout = 0;
  std::string all_slots_alias = "";
  for (const auto& alias : all_slots_) {
    all_slots_alias += alias + " ";
  }
  std::string use_slots_alias = "";
  for (const auto& alias : use_slots_) {
    use_slots_alias += alias + " ";
  }
  VLOG(3) << "total slots num: " << all_slots_.size();
  VLOG(3) << "total slots alias: " << all_slots_alias;
  VLOG(3) << "used slots num: " << use_slots_.size();
  VLOG(3) << "used slots alias: " << use_slots_alias;
  // Shared "where" hint printed after every format error.
  auto report_location = [&]() {
    VLOG(0) << "please check line<" << instance_cout << "> in file<"
            << filename << ">";
  };
  while (getline(fin, line)) {
    ++instance_cout;
    const char* str = line.c_str();
    char* endptr = const_cast<char*>(str);
    int len = line.length();
    for (size_t i = 0; i < all_slots_.size(); ++i) {
      // The strto* functions only ever set errno, they never clear it; reset
      // it before each conversion so a stale ERANGE is not misreported.
      errno = 0;
      auto num = strtol(endptr, &endptr, 10);
      if (num < 0) {
        VLOG(0) << "error: the number of ids is a negative number: " << num;
        report_location();
        return false;
      } else if (num == 0) {
        VLOG(0)
            << "error: the number of ids can not be zero, you need "
               "padding it in data generator; or if there is something wrong"
               " with the data, please check if the data contains unresolvable "
               "characters.";
        report_location();
        return false;
      } else if (errno == ERANGE || num > INT_MAX) {
        VLOG(0) << "error: the number of ids greater than INT_MAX";
        report_location();
        return false;
      }
      if (all_slots_type_[i] == "float") {
        for (int j = 0; j < num; ++j) {
          errno = 0;
          strtof(endptr, &endptr);
          if (errno == ERANGE) {
            VLOG(0) << "error: the value is out of the range of "
                       "representable values for float";
            report_location();
            return false;
          }
          // Reaching the end of the line before the last value means the
          // declared count is larger than the number of values present.
          if (j + 1 != num && endptr - str == len) {
            VLOG(0) << "error: there is a wrong with the number of ids.";
            report_location();
            return false;
          }
        }
      } else if (all_slots_type_[i] == "uint64") {
        for (int j = 0; j < num; ++j) {
          errno = 0;
          strtoull(endptr, &endptr, 10);
          if (errno == ERANGE) {
            VLOG(0) << "error: the value is out of the range of "
                       "representable values for uint64_t";
            report_location();
            return false;
          }
          if (j + 1 != num && endptr - str == len) {
            VLOG(0) << "error: there is a wrong with the number of ids.";
            report_location();
            return false;
          }
        }
      } else {
        VLOG(0) << "error: this type<" << all_slots_type_[i]
                << "> is not supported";
        return false;
      }
    }
    // It may be added '\t' character to the end of the output of reduce
    // task when processes data by Hadoop(when the output of the reduce
    // task of Hadoop has only one field, it will add a '\t' at the end
    // of the line by default, and you can use this option to avoid it:
    // `-D mapred.textoutputformat.ignoreseparator=true`), which does
    // not affect the correctness of the data. Therefore, it should be
    // judged that the data is not normal when the end of each line of
    // data contains characters which are not spaces.
    while (endptr - str != len) {
      // isspace() requires a value representable as unsigned char.
      if (!isspace(static_cast<unsigned char>(*(endptr++)))) {
        VLOG(0)
            << "error: there is some extra characters at the end of the line.";
        report_location();
        return false;
      }
    }
  }
  VLOG(3) << "instances cout: " << instance_cout;
  VLOG(3) << "The file format is correct";
#endif
  return true;
}
// Reads one text line from the current stream and parses it into `*instance`
// (one MultiSlotType per used slot).
//
// Each slot on the line is "<count> <v1> ... <vcount>". Values of used slots
// are stored; values of unused slots are skipped.
// Returns false when no further line can be read, true otherwise (and
// unconditionally true outside _LINUX).
bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
    std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
  thread_local string::LineFileReader reader;
  if (!reader.getline(&*(fp_.get()))) {
    return false;
  }
  const int use_slots_num = use_slots_.size();
  instance->resize(use_slots_num);
  const char* str = reader.get();
  char* endptr = const_cast<char*>(str);
  int pos = 0;
  for (size_t i = 0; i < use_slots_index_.size(); ++i) {
    int idx = use_slots_index_[i];
    int num = strtol(&str[pos], &endptr, 10);
    PADDLE_ENFORCE(
        num,
        "The number of ids can not be zero, you need padding "
        "it in data generator; or if there is something wrong with "
        "the data, please check if the data contains unresolvable "
        "characters.\nplease check this error line: %s",
        str);
    if (idx != -1) {
      (*instance)[idx].Init(all_slots_type_[i]);
      if ((*instance)[idx].GetType()[0] == 'f') {  // float
        for (int j = 0; j < num; ++j) {
          float feasign = strtof(endptr, &endptr);
          (*instance)[idx].AddValue(feasign);
        }
      } else if ((*instance)[idx].GetType()[0] == 'u') {  // uint64
        for (int j = 0; j < num; ++j) {
          uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
          (*instance)[idx].AddValue(feasign);
        }
      }
      pos = endptr - str;
    } else {
      // Unused slot: step over its `num` value tokens without converting
      // them. The scan starts right after the count and never runs past the
      // terminating NUL, so a short line cannot cause an out-of-bounds read
      // or an endless loop.
      const char* cursor = endptr;
      for (int j = 0; j < num; ++j) {
        while (isspace(static_cast<unsigned char>(*cursor))) {
          ++cursor;
        }
        while (*cursor != '\0' &&
               !isspace(static_cast<unsigned char>(*cursor))) {
          ++cursor;
        }
      }
      pos = cursor - str;
    }
  }
  return true;
#else
  return true;
#endif
}
// Reads one text line from file_ and parses it into `*instance` (one
// MultiSlotType per used slot). Each slot on the line is
// "<count> <v1> ... <vcount>"; values of unused slots are skipped.
//
// Returns true when a line was read and parsed, false when no further line
// is available (and always false outside _LINUX).
bool MultiSlotDataFeed::ParseOneInstance(std::vector<MultiSlotType>* instance) {
#ifdef _LINUX
  std::string line;
  if (!getline(file_, line)) {
    return false;
  }
  const int use_slots_num = use_slots_.size();
  instance->resize(use_slots_num);
  // parse line
  const char* str = line.c_str();
  char* endptr = const_cast<char*>(str);
  int pos = 0;
  for (size_t i = 0; i < use_slots_index_.size(); ++i) {
    int idx = use_slots_index_[i];
    int num = strtol(&str[pos], &endptr, 10);
    PADDLE_ENFORCE(
        num,
        "The number of ids can not be zero, you need padding "
        "it in data generator; or if there is something wrong with "
        "the data, please check if the data contains unresolvable "
        "characters.\nplease check this error line: %s",
        str);
    if (idx != -1) {
      (*instance)[idx].Init(all_slots_type_[i]);
      if ((*instance)[idx].GetType()[0] == 'f') {  // float
        for (int j = 0; j < num; ++j) {
          float feasign = strtof(endptr, &endptr);
          (*instance)[idx].AddValue(feasign);
        }
      } else if ((*instance)[idx].GetType()[0] == 'u') {  // uint64
        for (int j = 0; j < num; ++j) {
          uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
          (*instance)[idx].AddValue(feasign);
        }
      }
      pos = endptr - str;
    } else {
      // Unused slot: step over its `num` value tokens. Scanning stops at the
      // terminating NUL, so a missing trailing separator cannot produce a
      // negative or out-of-range position.
      const char* cursor = endptr;
      for (int j = 0; j < num; ++j) {
        while (isspace(static_cast<unsigned char>(*cursor))) {
          ++cursor;
        }
        while (*cursor != '\0' &&
               !isspace(static_cast<unsigned char>(*cursor))) {
          ++cursor;
        }
      }
      pos = cursor - str;
    }
  }
  // A line was consumed and parsed: report success to the caller.
  return true;
#else
  return false;
#endif
}
// Appends `instance` (one entry per used slot) to the batch accumulator
// `*ins_vec`. For the first instance of a batch (`index` == 0) the
// accumulator is first resized and every slot is re-initialised with the
// instance's type and a fresh offset table. No-op outside _LINUX.
void MultiSlotDataFeed::AddInstanceToInsVec(
    std::vector<MultiSlotType>* ins_vec,
    const std::vector<MultiSlotType>& instance, int index) {
#ifdef _LINUX
  const size_t slot_count = instance.size();
  std::vector<MultiSlotType>& batch = *ins_vec;
  const bool first_of_batch = (index == 0);
  if (first_of_batch) {
    batch.resize(slot_count);
    for (size_t slot = 0; slot < slot_count; ++slot) {
      batch[slot].Init(instance[slot].GetType());
      batch[slot].InitOffset();
    }
  }
  for (size_t slot = 0; slot < slot_count; ++slot) {
    batch[slot].AddIns(instance[slot]);
  }
#endif
}
// Copies an assembled batch into the bound feed tensors, one tensor per used
// slot: the values are written as a {total, 1} tensor, the slot's offsets
// become its LoD, and dense slots are reshaped to their declared shape with
// the -1 dimension inferred from the data. Slots without a bound tensor are
// skipped. No-op outside _LINUX.
void MultiSlotDataFeed::PutToFeedVec(
    const std::vector<MultiSlotType>& ins_vec) {
#ifdef _LINUX
  // inductive_shape_index_ and total_dims_without_inductive_ are filled by
  // position among ALL slots, whereas this loop walks the USED slots. Build
  // the used -> all mapping so the right entries are read when some slots
  // are unused.
  std::vector<size_t> all_slot_of(use_slots_.size(), 0);
  for (size_t k = 0; k < use_slots_index_.size(); ++k) {
    const int used = use_slots_index_[k];
    if (used >= 0 && static_cast<size_t>(used) < all_slot_of.size()) {
      all_slot_of[used] = k;
    }
  }
  for (size_t i = 0; i < use_slots_.size(); ++i) {
    if (feed_vec_[i] == nullptr) {
      continue;
    }
    const auto& type = ins_vec[i].GetType();
    const auto& offset = ins_vec[i].GetOffset();
    int total_instance = static_cast<int>(offset.back());
    if (type[0] == 'f') {  // float
      const auto& feasign = ins_vec[i].GetFloatData();
      float* tensor_ptr =
          feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
      // data() is well defined for an empty container, unlike &feasign[0].
      CopyToFeedTensor(tensor_ptr, feasign.data(),
                       total_instance * sizeof(float));
    } else if (type[0] == 'u') {  // uint64
      // no uint64_t type in paddlepaddle
      const auto& feasign = ins_vec[i].GetUint64Data();
      int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
          {total_instance, 1}, this->place_);
      CopyToFeedTensor(tensor_ptr, feasign.data(),
                       total_instance * sizeof(int64_t));
    }
    LoD data_lod{offset};
    feed_vec_[i]->set_lod(data_lod);
    if (use_slots_is_dense_[i]) {
      const size_t all_idx = all_slot_of[i];
      if (inductive_shape_index_[all_idx] != -1) {
        use_slots_shape_[i][inductive_shape_index_[all_idx]] =
            total_instance / total_dims_without_inductive_[all_idx];
      }
      feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
    }
  }
#endif
}
// Configures the feed from `data_feed_desc`: batch size, the per-slot
// name/type tables, the list of used slots with their shapes, and the pipe
// command. The descriptor must carry a multi-slot description.
//
// Tables sized by all_slot_num are indexed by the slot's position in the
// descriptor; use_slots_, use_slots_is_dense_ and use_slots_shape_ are
// indexed by the slot's position among the *used* slots.
void MultiSlotInMemoryDataFeed::Init(
    const paddle::framework::DataFeedDesc& data_feed_desc) {
  finish_init_ = false;
  finish_set_filelist_ = false;
  finish_start_ = false;
  PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
                 "Multi_slot_desc has not been set.");
  // Bind by reference: the descriptor outlives this call, no copy needed.
  const paddle::framework::MultiSlotDesc& multi_slot_desc =
      data_feed_desc.multi_slot_desc();
  SetBatchSize(data_feed_desc.batch_size());
  const size_t all_slot_num = multi_slot_desc.slots_size();
  all_slots_.resize(all_slot_num);
  all_slots_type_.resize(all_slot_num);
  use_slots_index_.resize(all_slot_num);
  total_dims_without_inductive_.resize(all_slot_num);
  inductive_shape_index_.resize(all_slot_num);
  use_slots_.clear();
  use_slots_is_dense_.clear();
  // Must be cleared together with the two lists above; otherwise a second
  // Init() appends new shapes after the stale ones and the used-slot
  // indices no longer line up.
  use_slots_shape_.clear();
  for (size_t i = 0; i < all_slot_num; ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
    total_dims_without_inductive_[i] = 1;
    inductive_shape_index_[i] = -1;
    if (!slot.is_used()) {
      continue;
    }
    use_slots_.push_back(all_slots_[i]);
    use_slots_is_dense_.push_back(slot.is_dense());
    const int shape_rank = slot.shape_size();
    std::vector<int> local_shape;
    local_shape.reserve(shape_rank);
    for (int j = 0; j < shape_rank; ++j) {
      const auto dim = slot.shape(j);
      local_shape.push_back(dim);
      if (slot.is_dense()) {
        // Positive dims contribute to the fixed element count; a -1 dim is
        // the one inferred from the data at feed time.
        if (dim > 0) {
          total_dims_without_inductive_[i] *= dim;
        }
        if (dim == -1) {
          inductive_shape_index_[i] = j;
        }
      }
    }
    use_slots_shape_.push_back(local_shape);
  }
  feed_vec_.resize(use_slots_.size());
  pipe_command_ = data_feed_desc.pipe_command();
  finish_init_ = true;
}
// Reads one text line from the current stream and parses it into `*instance`.
//
// Line layout: optionally "1 <ins_id>" (when parse_ins_id_ is set), then
// optionally "1 <content>" (when parse_content_ is set), then for every
// declared slot "<count> <v1> ... <vcount>". Values of used slots are
// appended as FeatureItems tagged with the used-slot index; zero values are
// dropped unless the slot is dense. Values of unused slots are skipped.
//
// Returns false when no further line can be read (and always outside _LINUX),
// true otherwise.
bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
#ifdef _LINUX
  thread_local string::LineFileReader reader;
  if (!reader.getline(&*(fp_.get()))) {
    return false;
  }
  const char* str = reader.get();
  char* endptr = const_cast<char*>(str);
  int pos = 0;
  // Reads a "1 <token>" prefix field starting at `pos` into `*out` and moves
  // `pos` past it. Every step stops at the terminating NUL so a truncated
  // line cannot be over-read.
  auto read_single_token = [&](std::string* out) {
    int num = strtol(&str[pos], &endptr, 10);
    CHECK(num == 1);  // NOLINT
    pos = endptr - str;
    if (str[pos] != '\0') {
      ++pos;  // separator between the count and the token
    }
    size_t len = 0;
    while (str[pos + len] != ' ' && str[pos + len] != '\0') {
      ++len;
    }
    out->assign(str + pos, len);
    pos += static_cast<int>(len);
    if (str[pos] != '\0') {
      ++pos;  // separator after the token
    }
  };
  if (parse_ins_id_) {
    read_single_token(&instance->ins_id_);
    VLOG(3) << "ins_id " << instance->ins_id_;
  }
  if (parse_content_) {
    read_single_token(&instance->content_);
    VLOG(3) << "content " << instance->content_;
  }
  for (size_t i = 0; i < use_slots_index_.size(); ++i) {
    int idx = use_slots_index_[i];
    int num = strtol(&str[pos], &endptr, 10);
    PADDLE_ENFORCE(
        num,
        "The number of ids can not be zero, you need padding "
        "it in data generator; or if there is something wrong with "
        "the data, please check if the data contains unresolvable "
        "characters.\nplease check this error line: %s",
        str);
    if (idx != -1) {
      // use_slots_is_dense_ has one entry per USED slot, so it is indexed
      // with idx, not with the all-slot index i.
      const bool is_dense = use_slots_is_dense_[idx];
      if (all_slots_type_[i][0] == 'f') {  // float
        for (int j = 0; j < num; ++j) {
          float feasign = strtof(endptr, &endptr);
          // if float feasign is equal to zero, ignore it
          // except when slot is dense
          if (fabs(feasign) < 1e-6 && !is_dense) {
            continue;
          }
          FeatureKey f;
          f.float_feasign_ = feasign;
          instance->float_feasigns_.push_back(FeatureItem(f, idx));
        }
      } else if (all_slots_type_[i][0] == 'u') {  // uint64
        for (int j = 0; j < num; ++j) {
          uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
          // if uint64 feasign is equal to zero, ignore it
          // except when slot is dense
          if (feasign == 0 && !is_dense) {
            continue;
          }
          FeatureKey f;
          f.uint64_feasign_ = feasign;
          instance->uint64_feasigns_.push_back(FeatureItem(f, idx));
        }
      }
      pos = endptr - str;
    } else {
      // Unused slot: step over its `num` value tokens without converting
      // them, never scanning past the terminating NUL.
      const char* cursor = endptr;
      for (int j = 0; j < num; ++j) {
        while (isspace(static_cast<unsigned char>(*cursor))) {
          ++cursor;
        }
        while (*cursor != '\0' &&
               !isspace(static_cast<unsigned char>(*cursor))) {
          ++cursor;
        }
      }
      pos = cursor - str;
    }
  }
  instance->float_feasigns_.shrink_to_fit();
  instance->uint64_feasigns_.shrink_to_fit();
  return true;
#else
  return false;
#endif
}
// Reads one text line from file_ and parses it into `*instance`. For every
// declared slot the line holds "<count> <v1> ... <vcount>". Values of used
// slots are appended as FeatureItems tagged with the used-slot index; zero
// values are dropped unless the slot is dense (same rule as the pipe-based
// parser). Values of unused slots are skipped.
//
// Returns true when a line was read and parsed, false when no further line
// is available (and always false outside _LINUX).
bool MultiSlotInMemoryDataFeed::ParseOneInstance(Record* instance) {
#ifdef _LINUX
  std::string line;
  if (!getline(file_, line)) {
    return false;
  }
  VLOG(3) << line;
  // parse line
  const char* str = line.c_str();
  char* endptr = const_cast<char*>(str);
  int pos = 0;
  for (size_t i = 0; i < use_slots_index_.size(); ++i) {
    int idx = use_slots_index_[i];
    int num = strtol(&str[pos], &endptr, 10);
    PADDLE_ENFORCE(
        num,
        "The number of ids can not be zero, you need padding "
        "it in data generator; or if there is something wrong with "
        "the data, please check if the data contains unresolvable "
        "characters.\nplease check this error line: %s",
        str);
    if (idx != -1) {
      // Dense slots must keep their zeros, otherwise the element count no
      // longer matches the declared shape.
      const bool is_dense = use_slots_is_dense_[idx];
      if (all_slots_type_[i][0] == 'f') {  // float
        for (int j = 0; j < num; ++j) {
          float feasign = strtof(endptr, &endptr);
          if (fabs(feasign) < 1e-6 && !is_dense) {
            continue;
          }
          FeatureKey f;
          f.float_feasign_ = feasign;
          instance->float_feasigns_.push_back(FeatureItem(f, idx));
        }
      } else if (all_slots_type_[i][0] == 'u') {  // uint64
        for (int j = 0; j < num; ++j) {
          uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
          if (feasign == 0 && !is_dense) {
            continue;
          }
          FeatureKey f;
          f.uint64_feasign_ = feasign;
          instance->uint64_feasigns_.push_back(FeatureItem(f, idx));
        }
      }
      pos = endptr - str;
    } else {
      // Unused slot: step over its `num` value tokens. Scanning stops at the
      // terminating NUL, so a missing trailing separator cannot produce a
      // negative or out-of-range position.
      const char* cursor = endptr;
      for (int j = 0; j < num; ++j) {
        while (isspace(static_cast<unsigned char>(*cursor))) {
          ++cursor;
        }
        while (*cursor != '\0' &&
               !isspace(static_cast<unsigned char>(*cursor))) {
          ++cursor;
        }
      }
      pos = cursor - str;
    }
  }
  instance->float_feasigns_.shrink_to_fit();
  instance->uint64_feasigns_.shrink_to_fit();
  return true;
#else
  return false;
#endif
}
// Converts a batch of Records into the bound feed tensors.
//
// Phase 1 regroups the per-record feature items by used slot, padding a slot
// with a single zero for every record that carries no value for it, and
// records the running element count per record as that slot's offsets.
// Phase 2 writes each slot's values as a {total, 1} tensor, installs the
// offsets as LoD and reshapes dense slots to their declared shape with the
// -1 dimension inferred from the data. Slots without a bound tensor are
// skipped in phase 2. The ids and contents of the records are kept in
// ins_id_vec_ / ins_content_vec_. No-op outside _LINUX.
void MultiSlotInMemoryDataFeed::PutToFeedVec(
    const std::vector<Record>& ins_vec) {
#ifdef _LINUX
  const size_t use_slot_num = use_slots_.size();
  // all_slots_type_, inductive_shape_index_ and
  // total_dims_without_inductive_ are filled by position among ALL slots,
  // whereas the loops below walk the USED slots. Build the used -> all
  // mapping so the right entries are read when some slots are unused.
  std::vector<size_t> all_slot_of(use_slot_num, 0);
  for (size_t k = 0; k < use_slots_index_.size(); ++k) {
    const int used = use_slots_index_[k];
    if (used >= 0 && static_cast<size_t>(used) < use_slot_num) {
      all_slot_of[used] = k;
    }
  }
  std::vector<std::vector<float>> batch_float_feasigns(use_slot_num,
                                                       std::vector<float>());
  std::vector<std::vector<uint64_t>> batch_uint64_feasigns(
      use_slot_num, std::vector<uint64_t>());
  std::vector<std::vector<size_t>> offset(use_slot_num,
                                          std::vector<size_t>{0});
  std::vector<bool> visit(use_slot_num, false);
  ins_content_vec_.clear();
  ins_content_vec_.reserve(ins_vec.size());
  ins_id_vec_.clear();
  ins_id_vec_.reserve(ins_vec.size());
  for (size_t i = 0; i < ins_vec.size(); ++i) {
    auto& r = ins_vec[i];
    ins_id_vec_.push_back(r.ins_id_);
    ins_content_vec_.push_back(r.content_);
    for (auto& item : r.float_feasigns_) {
      batch_float_feasigns[item.slot()].push_back(item.sign().float_feasign_);
      visit[item.slot()] = true;
    }
    for (auto& item : r.uint64_feasigns_) {
      batch_uint64_feasigns[item.slot()].push_back(item.sign().uint64_feasign_);
      visit[item.slot()] = true;
    }
    for (size_t j = 0; j < use_slot_num; ++j) {
      const auto& type = all_slots_type_[all_slot_of[j]];
      if (visit[j]) {
        visit[j] = false;
      } else {
        // fill slot value with default value 0
        if (type[0] == 'f') {  // float
          batch_float_feasigns[j].push_back(0.0);
        } else if (type[0] == 'u') {  // uint64
          batch_uint64_feasigns[j].push_back(0);
        }
      }
      // get offset of this ins in this slot
      if (type[0] == 'f') {  // float
        offset[j].push_back(batch_float_feasigns[j].size());
      } else if (type[0] == 'u') {  // uint64
        offset[j].push_back(batch_uint64_feasigns[j].size());
      }
    }
  }
  for (size_t i = 0; i < use_slot_num; ++i) {
    if (feed_vec_[i] == nullptr) {
      continue;
    }
    const size_t all_idx = all_slot_of[i];
    int total_instance = offset[i].back();
    const auto& type = all_slots_type_[all_idx];
    if (type[0] == 'f') {  // float
      float* feasign = batch_float_feasigns[i].data();
      float* tensor_ptr =
          feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
      CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(float));
    } else if (type[0] == 'u') {  // uint64
      // no uint64_t type in paddlepaddle
      uint64_t* feasign = batch_uint64_feasigns[i].data();
      int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
          {total_instance, 1}, this->place_);
      CopyToFeedTensor(tensor_ptr, feasign, total_instance * sizeof(int64_t));
    }
    auto& slot_offset = offset[i];
    LoD data_lod{slot_offset};
    feed_vec_[i]->set_lod(data_lod);
    if (use_slots_is_dense_[i]) {
      if (inductive_shape_index_[all_idx] != -1) {
        use_slots_shape_[i][inductive_shape_index_[all_idx]] =
            total_instance / total_dims_without_inductive_[all_idx];
      }
      feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
    }
  }
#endif
}
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
// Publishes the current mini-batch (ins_vec_) to the feed tensors, one tensor
// per used slot: values go into a {total, 1} tensor, the slot's offsets become
// its LoD, and dense slots are reshaped to their declared shape after
// verifying that the shape's element count equals the amount of data.
template <typename T>
void PrivateInstantDataFeed<T>::PutToFeedVec() {
  const size_t slot_count = use_slots_.size();
  for (size_t i = 0; i < slot_count; ++i) {
    const auto& type = ins_vec_[i].GetType();
    const auto& offset = ins_vec_[i].GetOffset();
    int total_instance = static_cast<int>(offset.back());

    const char kind = type[0];
    if (kind == 'f') {  // float
      const auto& values = ins_vec_[i].GetFloatData();
      float* tensor_ptr =
          feed_vec_[i]->mutable_data<float>({total_instance, 1}, this->place_);
      CopyToFeedTensor(tensor_ptr, &values[0], total_instance * sizeof(float));
    } else if (kind == 'u') {  // uint64
      // no uint64_t type in paddlepaddle
      const auto& values = ins_vec_[i].GetUint64Data();
      int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
          {total_instance, 1}, this->place_);
      CopyToFeedTensor(tensor_ptr, &values[0],
                       total_instance * sizeof(int64_t));
    }

    LoD data_lod{offset};
    feed_vec_[i]->set_lod(data_lod);

    if (!use_slots_is_dense_[i]) {
      continue;
    }
    // Dense slot: the product of the declared dimensions must equal the
    // number of values actually read.
    int64_t total_dims = 1;
    for (const auto dim : use_slots_shape_[i]) {
      total_dims *= dim;
    }
    PADDLE_ENFORCE(
        total_dims == total_instance,
        "The actual data size of slot[%s] doesn't match its declaration",
        use_slots_[i].c_str());
    feed_vec_[i]->Resize(framework::make_ddim(use_slots_shape_[i]));
  }
}
// Produces the next mini-batch and returns its size.
//
// First tries the current file. When that is exhausted, the file is
// released, the next one is picked and preprocessed, and its first
// mini-batch is parsed (which must succeed). Returns -1 when no file is left
// or preprocessing reports failure.
template <typename T>
int PrivateInstantDataFeed<T>::Next() {
  const bool got_batch = ParseOneMiniBatch();
  if (!got_batch) {
    // Current file is drained: release it and move on to the next one.
    Postprocess();
    std::string filename;
    if (!PickOneFile(&filename)) {
      return -1;
    }
    if (!Preprocess(filename)) {
      return -1;
    }
    PADDLE_ENFORCE(true == ParseOneMiniBatch(), "Fail to parse mini-batch data");
  }
  PutToFeedVec();
  return ins_vec_[0].GetBatchSize();
}
// Configures the feed from `data_feed_desc`: batch size, the per-slot
// name/type tables, the list of used slots with their shapes, and for every
// dense used slot the positions of its -1 ("inductive") dimensions.
// The descriptor must carry a multi-slot description.
//
// Tables sized by all_slot_num are indexed by the slot's position in the
// descriptor; use_slots_, use_slots_is_dense_ and use_slots_shape_ are
// indexed by the slot's position among the *used* slots.
template <typename T>
void PrivateInstantDataFeed<T>::Init(const DataFeedDesc& data_feed_desc) {
  finish_init_ = false;
  finish_set_filelist_ = false;
  finish_start_ = false;
  PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(),
                 "Multi_slot_desc has not been set.");
  // Bind by reference: the descriptor outlives this call, no copy needed.
  const paddle::framework::MultiSlotDesc& multi_slot_desc =
      data_feed_desc.multi_slot_desc();
  SetBatchSize(data_feed_desc.batch_size());
  const size_t all_slot_num = multi_slot_desc.slots_size();
  all_slots_.resize(all_slot_num);
  all_slots_type_.resize(all_slot_num);
  use_slots_index_.resize(all_slot_num);
  // clear() before resize(): resize() alone keeps the inner index lists of a
  // previous Init(), and the push_back below would then append duplicates.
  multi_inductive_shape_index_.clear();
  multi_inductive_shape_index_.resize(all_slot_num);
  use_slots_.clear();
  use_slots_is_dense_.clear();
  // Cleared for the same reason: shapes are appended per used slot.
  use_slots_shape_.clear();
  for (size_t i = 0; i < all_slot_num; ++i) {
    const auto& slot = multi_slot_desc.slots(i);
    all_slots_[i] = slot.name();
    all_slots_type_[i] = slot.type();
    use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1;
    if (!slot.is_used()) {
      continue;
    }
    use_slots_.push_back(all_slots_[i]);
    use_slots_is_dense_.push_back(slot.is_dense());
    const int shape_rank = slot.shape_size();
    std::vector<int> local_shape;
    local_shape.reserve(shape_rank);
    for (int j = 0; j < shape_rank; ++j) {
      const auto dim = slot.shape(j);
      local_shape.push_back(dim);
      if (slot.is_dense() && dim == -1) {
        multi_inductive_shape_index_[i].push_back(j);
      }
    }
    use_slots_shape_.push_back(local_shape);
  }
  feed_vec_.resize(use_slots_.size());
  ins_vec_.resize(use_slots_.size());
  finish_init_ = true;
}
// Explicit instantiation: the member templates above are defined in this
// translation unit, so the specialization used with MultiSlotType vectors is
// instantiated here.
template class PrivateInstantDataFeed<std::vector<MultiSlotType>>;
bool MultiSlotFileInstantDataFeed::Preprocess(const std::string& filename) {
fd_ = open(filename.c_str(), O_RDONLY);
PADDLE_ENFORCE(fd_ != -1, "Fail to open file: %s", filename.c_str());
struct stat sb;
fstat(fd_, &sb);
end_ = static_cast<size_t>(sb.st_size);
buffer_ =
reinterpret_cast<char*>(mmap(NULL, end_, PROT_READ, MAP_PRIVATE, fd_, 0));
PADDLE_ENFORCE(buffer_ != MAP_FAILED, strerror(errno));
offset_ = 0;
return true;
}
// Releases the resources of the current file: unmaps the buffer when one is
// mapped, and closes the descriptor (resetting end_ and offset_) when one is
// open. Always returns true.
bool MultiSlotFileInstantDataFeed::Postprocess() {
  const bool has_mapping = (buffer_ != nullptr);
  if (has_mapping) {
    // end_ still holds the mapped length here; it is reset further down.
    munmap(buffer_, end_);
    buffer_ = nullptr;
  }

  const bool has_open_file = (fd_ != -1);
  if (has_open_file) {
    close(fd_);
    fd_ = -1;
    offset_ = 0;
    end_ = 0;
  }
  return true;
}
bool MultiSlotFileInstantDataFeed::ParseOneMiniBatch() {
if (offset_ == end_) {
return false;
}
batch_size_ = 0;
while (batch_size_ < default_batch_size_ && offset_ < end_) {
for (size_t i = 0; i < use_slots_index_.size(); ++i) {
int idx = use_slots_index_[i];
char type = all_slots_type_[i][0];
uint16_t num = *reinterpret_cast<uint16_t*>(buffer_ + offset_);
PADDLE_ENFORCE(
num,
"The number of ids can not be zero, you need padding "
"it in data generator; or if there is something wrong with "
"the data, please check if the data contains unresolvable "
"characters.");
offset_ += sizeof(uint16_t);
if (idx != -1) {
int inductive_size = multi_inductive_shape_index_[i].size();
if (UNLIKELY(batch_size_ == 0)) {
ins_vec_[idx].Init(all_slots_type_[i], default_batch_size_ * num);
ins_vec_[idx].InitOffset(default_batch_size_);
uint64_t* inductive_shape =
reinterpret_cast<uint64_t*>(buffer_ + offset_);
for (int inductive_id = 0; inductive_id < inductive_size;
++inductive_id) {
use_slots_shape_[i][multi_inductive_shape_index_[i][inductive_id]] =
static_cast<int>(*(inductive_shape + inductive_id));
}
}
num -= inductive_size;
offset_ += sizeof(uint64_t) * inductive_size;
if (type == 'f') {
ins_vec_[idx].AppendValues(
reinterpret_cast<float*>(buffer_ + offset_), num);
offset_ += num * sizeof(float);
} else if (type == 'u') {
ins_vec_[idx].AppendValues(
reinterpret_cast<uint64_t*>(buffer_ + offset_), num);
offset_ += num * sizeof(uint64_t);
}
} else {
if (type == 'f') {
offset_ += num * sizeof(float);
} else if (type == 'u') {
offset_ += num * sizeof(uint64_t);
}
}
}
++batch_size_;
// OPTIMIZE: It is better to insert check codes between instances for format
// checking
}
PADDLE_ENFORCE(batch_size_ == default_batch_size_ || offset_ == end_,
"offset_ != end_");
return true;
}
#endif
} // namespace framework
} // namespace paddle