/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#if defined _WIN32 || defined __APPLE__
#else
#define _LINUX
#endif

#include <fstream>
#include <future>  // NOLINT
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/string/string_helper.h"

namespace paddle {
namespace framework {

// DataFeed is the abstract base class for all other DataFeeds.
// It is used to read files and parse the data for the subsequent trainer.
// Example:
//   DataFeed* reader =
//       paddle::framework::DataFeedFactory::CreateDataFeed(data_feed_name);
//   reader->Init(data_feed_desc);  // data_feed_desc is a protobuf object
//   reader->SetFileList(filelist);
//   const std::vector<std::string>& use_slot_alias =
//       reader->GetUseSlotAlias();
//   for (auto name : use_slot_alias) {  // for binding memory
//     reader->AddFeedVar(scope->Var(name), name);
//   }
//   reader->Start();
//   while (reader->Next()) {
//     // trainer does something
//   }
union FeatureKey {
  uint64_t uint64_feasign_;
  float float_feasign_;
};

struct FeatureItem {
  FeatureItem() {}
  FeatureItem(FeatureKey sign, uint16_t slot) {
    this->sign() = sign;
    this->slot() = slot;
  }
  FeatureKey& sign() { return *(reinterpret_cast<FeatureKey*>(sign_buffer())); }
  const FeatureKey& sign() const {
    const FeatureKey* ret = reinterpret_cast<FeatureKey*>(sign_buffer());
    return *ret;
  }
  uint16_t& slot() { return slot_; }
  const uint16_t& slot() const { return slot_; }

 private:
  char* sign_buffer() const { return const_cast<char*>(sign_); }
  char sign_[sizeof(FeatureKey)];
  uint16_t slot_;
};

// A Record is much smaller than a std::vector<MultiSlotType>.
struct Record {
  std::vector<FeatureItem> uint64_feasigns_;
  std::vector<FeatureItem> float_feasigns_;
  std::string ins_id_;
  std::string content_;
  uint64_t search_id;
  uint32_t rank;
  uint32_t cmatch;
};

struct PvInstanceObject {
  std::vector<Record*> ads;
  void merge_instance(Record* ins) { ads.push_back(ins); }
};

using PvInstance = PvInstanceObject*;

inline PvInstance make_pv_instance() { return new PvInstanceObject(); }

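// A minimal sketch (illustrative only, not part of the API) of how a Record
// might be assembled from a parsed instance, assuming slot 0 carries uint64
// feasigns:
//   Record rec;
//   FeatureKey key;
//   key.uint64_feasign_ = 1234567ULL;
//   rec.uint64_feasigns_.emplace_back(key, /*slot=*/0);
//   rec.ins_id_ = "ins_0";
//   PvInstance pv = make_pv_instance();
//   pv->merge_instance(&rec);  // group records (ads) under one pv instance
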
class DataFeed {
 public:
  DataFeed() {
    mutex_for_pick_file_ = nullptr;
    file_idx_ = nullptr;
    mutex_for_fea_num_ = nullptr;
    total_fea_num_ = nullptr;
  }
  virtual ~DataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
  virtual bool CheckFile(const char* filename) {
    PADDLE_THROW("This function(CheckFile) is not implemented.");
  }
  // Set the filelist for the DataFeed.
  // Note that all readers must be initialized (Init()) before this function is
  // called; otherwise Init() will reset the finish_set_filelist_ flag.
  virtual bool SetFileList(const std::vector<std::string>& files);
  virtual bool Start() = 0;

  // The trainer calls the Next() function, and the DataFeed loads a new batch
  // into feed_vec. The return value of this function is the size of the
  // current batch.
  virtual int Next() = 0;
  // Get the aliases of all slots defined in the proto file.
  virtual const std::vector<std::string>& GetAllSlotAlias() {
    return all_slots_;
  }
  // Get the aliases of the used slots defined in the proto file.
  virtual const std::vector<std::string>& GetUseSlotAlias() {
    return use_slots_;
  }
  // This function is used to bind feed_vec memory.
  virtual void AddFeedVar(Variable* var, const std::string& name);

  // This function is used to bind feed_vec memory in a given scope.
  virtual void AssignFeedVar(const Scope& scope);

  // This function does nothing by default.
  virtual void SetInputPvChannel(void* channel) {}
  // This function does nothing by default.
  virtual void SetOutputPvChannel(void* channel) {}
  // This function does nothing by default.
  virtual void SetConsumePvChannel(void* channel) {}

  // This function does nothing by default.
  virtual void SetInputChannel(void* channel) {}
  // This function does nothing by default.
  virtual void SetOutputChannel(void* channel) {}
  // This function does nothing by default.
  virtual void SetConsumeChannel(void* channel) {}
  // This function does nothing by default.
  virtual void SetThreadId(int thread_id) {}
  // This function does nothing by default.
  virtual void SetThreadNum(int thread_num) {}
  // This function does nothing by default.
  virtual void SetParseInsId(bool parse_ins_id) {}
  virtual void SetParseContent(bool parse_content) {}
  virtual void SetParseLogKey(bool parse_logkey) {}
  virtual void SetEnablePvMerge(bool enable_pv_merge) {}
  virtual void SetCurrentPhase(int current_phase) {}
  virtual void SetFileListMutex(std::mutex* mutex) {
    mutex_for_pick_file_ = mutex;
  }
  virtual void SetFeaNumMutex(std::mutex* mutex) { mutex_for_fea_num_ = mutex; }
  virtual void SetFileListIndex(size_t* file_index) { file_idx_ = file_index; }
  virtual void SetFeaNum(uint64_t* fea_num) { total_fea_num_ = fea_num; }
  virtual const std::vector<std::string>& GetInsIdVec() const {
    return ins_id_vec_;
  }
  virtual const std::vector<std::string>& GetInsContentVec() const {
    return ins_content_vec_;
  }
  virtual int GetCurBatchSize() { return batch_size_; }
  virtual void LoadIntoMemory() {
    PADDLE_THROW("This function(LoadIntoMemory) is not implemented.");
  }
  virtual void SetPlace(const paddle::platform::Place& place) {
    place_ = place;
  }
  virtual const paddle::platform::Place& GetPlace() const { return place_; }

 protected:
  // The following three functions are used to check that the calls happen in
  // this order:
  //   Init() -> SetFileList() -> Start() -> Next()
  virtual void CheckInit();
  virtual void CheckSetFileList();
  virtual void CheckStart();
  virtual void SetBatchSize(
      int batch);  // the batch size is set in the Init() function
  // This function picks one file from the global filelist (thread safe).
  virtual bool PickOneFile(std::string* filename);
  virtual void CopyToFeedTensor(void* dst, const void* src, size_t size);

  std::vector<std::string> filelist_;
  size_t* file_idx_;
  std::mutex* mutex_for_pick_file_;
  std::mutex* mutex_for_fea_num_ = nullptr;
  uint64_t* total_fea_num_ = nullptr;
  uint64_t fea_num_ = 0;

  // The aliases of the used slots; their order is determined by
  // data_feed_desc (proto object).
  std::vector<std::string> use_slots_;
  std::vector<bool> use_slots_is_dense_;

  // The aliases of all slots; their order is determined by data_feed_desc
  // (proto object).
  std::vector<std::string> all_slots_;
  std::vector<std::string> all_slots_type_;
  std::vector<std::vector<int>> use_slots_shape_;
  std::vector<int> inductive_shape_index_;
  std::vector<int> total_dims_without_inductive_;
  // For the inductive shape passed within the data.
  std::vector<std::vector<int>> multi_inductive_shape_index_;
  std::vector<int>
      use_slots_index_;  // -1: not used; >=0: the index in use_slots_

  // The data read by DataFeed is stored here.
  std::vector<LoDTensor*> feed_vec_;

  LoDTensor* rank_offset_;

  // The batch size defined by the user.
  int default_batch_size_;
  // The current batch size.
  int batch_size_;

  bool finish_init_;
  bool finish_set_filelist_;
  bool finish_start_;
  std::string pipe_command_;
  std::vector<std::string> ins_id_vec_;
  std::vector<std::string> ins_content_vec_;
  platform::Place place_;

  // The input type of the pipe reader: 0 for one sample, 1 for one batch.
  int input_type_;
};

// PrivateQueueDataFeed is the base class for other DataFeeds.
// It uses a read thread to read files and parse the data into a private
// (thread-level) queue, and data is taken from this queue when the trainer
// calls Next().
template <typename T>
class PrivateQueueDataFeed : public DataFeed {
 public:
  PrivateQueueDataFeed() {}
  virtual ~PrivateQueueDataFeed() {}
  virtual bool Start();
  virtual int Next();

 protected:
  // The thread function for reading and parsing files.
  virtual void ReadThread();
  // This function sets the private queue size; reading is most efficient when
  // the queue size is close to the batch size.
  virtual void SetQueueSize(int queue_size);
  // The reading and parsing methods called in ReadThread().
  virtual bool ParseOneInstance(T* instance) = 0;
  virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
  // This function appends an instance to vec_ins.
  virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
                                   int index) = 0;
  // This function copies ins_vec into feed_vec.
  virtual void PutToFeedVec(const T& ins_vec) = 0;

  // The thread for reading files.
  std::thread read_thread_;
  // Reading with ifstream line by line and parsing each line is faster than
  // reading a whole buffer with fread and parsing it.
  // For a 601 MB real-world dataset:
  //   ifstream, line-by-line parse: 6034 ms
  //   fread, buffer-by-buffer parse: 7097 ms
  std::ifstream file_;
  std::shared_ptr<FILE> fp_;
  size_t queue_size_;
  string::LineFileReader reader_;
  // The queue for storing parsed data.
  std::shared_ptr<paddle::framework::ChannelObject<T>> queue_;
};

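// A rough sketch of the expected producer/consumer behaviour (illustrative
// only, based on the declarations above): Start() launches read_thread_
// running ReadThread(), which calls ParseOneInstance() or
// ParseOneInstanceFromPipe() and pushes parsed instances into queue_; Next()
// then pops instances from queue_, merges them with AddInstanceToInsVec() and
// copies the merged batch into feed_vec_ via PutToFeedVec().
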
template <typename T>
class InMemoryDataFeed : public DataFeed {
 public:
  InMemoryDataFeed();
  virtual ~InMemoryDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc) = 0;
  virtual bool Start();
  virtual int Next();
  virtual void SetInputPvChannel(void* channel);
  virtual void SetOutputPvChannel(void* channel);
  virtual void SetConsumePvChannel(void* channel);

  virtual void SetInputChannel(void* channel);
  virtual void SetOutputChannel(void* channel);
  virtual void SetConsumeChannel(void* channel);
  virtual void SetThreadId(int thread_id);
  virtual void SetThreadNum(int thread_num);
  virtual void SetParseInsId(bool parse_ins_id);
  virtual void SetParseContent(bool parse_content);
  virtual void SetParseLogKey(bool parse_logkey);
  virtual void SetEnablePvMerge(bool enable_pv_merge);
  virtual void SetCurrentPhase(int current_phase);
  virtual void LoadIntoMemory();

 protected:
  virtual bool ParseOneInstance(T* instance) = 0;
  virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
  virtual void PutToFeedVec(const std::vector<T>& ins_vec) = 0;

  int thread_id_;
  int thread_num_;
  bool parse_ins_id_;
  bool parse_content_;
  bool parse_logkey_;
  bool enable_pv_merge_;
  int current_phase_{-1};  // only used in unit tests
  std::ifstream file_;
  std::shared_ptr<FILE> fp_;
  paddle::framework::ChannelObject<T>* input_channel_;
  paddle::framework::ChannelObject<T>* output_channel_;
  paddle::framework::ChannelObject<T>* consume_channel_;

  paddle::framework::ChannelObject<PvInstance>* input_pv_channel_;
  paddle::framework::ChannelObject<PvInstance>* output_pv_channel_;
  paddle::framework::ChannelObject<PvInstance>* consume_pv_channel_;
};

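// A minimal wiring sketch (illustrative only), assuming the channels are
// created by the owning dataset, e.g. with paddle::framework::MakeChannel<T>()
// from channel.h; the void* arguments are cast back to ChannelObject<T>*
// internally:
//   auto input = paddle::framework::MakeChannel<Record>();
//   auto output = paddle::framework::MakeChannel<Record>();
//   auto consume = paddle::framework::MakeChannel<Record>();
//   feed->SetInputChannel(input.get());
//   feed->SetOutputChannel(output.get());
//   feed->SetConsumeChannel(consume.get());
//   feed->LoadIntoMemory();  // parses the assigned files into the channel
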
// This class defines the data type of an instance (ins_vec) in
// MultiSlotDataFeed.
class MultiSlotType {
 public:
  MultiSlotType() {}
  ~MultiSlotType() {}
  void Init(const std::string& type, size_t reserved_size = 0) {
    CheckType(type);
    if (type_[0] == 'f') {
      float_feasign_.clear();
      if (reserved_size) {
        float_feasign_.reserve(reserved_size);
      }
    } else if (type_[0] == 'u') {
      uint64_feasign_.clear();
      if (reserved_size) {
        uint64_feasign_.reserve(reserved_size);
      }
    }
    type_ = type;
  }
  void InitOffset(size_t max_batch_size = 0) {
    if (max_batch_size > 0) {
      offset_.reserve(max_batch_size + 1);
    }
    offset_.resize(1);
    // A LoDTensor's lod starts from 0, so the lod has one more element than
    // the data it describes.
    offset_[0] = 0;
  }
  const std::vector<size_t>& GetOffset() const { return offset_; }
  std::vector<size_t>& MutableOffset() { return offset_; }
  void AddValue(const float v) {
    CheckFloat();
    float_feasign_.push_back(v);
  }
  void AddValue(const uint64_t v) {
    CheckUint64();
    uint64_feasign_.push_back(v);
  }
  void CopyValues(const float* input, size_t size) {
    CheckFloat();
    float_feasign_.resize(size);
    memcpy(float_feasign_.data(), input, size * sizeof(float));
  }
  void CopyValues(const uint64_t* input, size_t size) {
    CheckUint64();
    uint64_feasign_.resize(size);
    memcpy(uint64_feasign_.data(), input, size * sizeof(uint64_t));
  }
  void AddIns(const MultiSlotType& ins) {
    if (ins.GetType()[0] == 'f') {  // float
      CheckFloat();
      auto& vec = ins.GetFloatData();
      offset_.push_back(offset_.back() + vec.size());
      float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end());
    } else if (ins.GetType()[0] == 'u') {  // uint64
      CheckUint64();
      auto& vec = ins.GetUint64Data();
      offset_.push_back(offset_.back() + vec.size());
      uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end());
    }
  }
  void AppendValues(const uint64_t* input, size_t size) {
    CheckUint64();
    offset_.push_back(offset_.back() + size);
    uint64_feasign_.insert(uint64_feasign_.end(), input, input + size);
  }
  void AppendValues(const float* input, size_t size) {
    CheckFloat();
    offset_.push_back(offset_.back() + size);
    float_feasign_.insert(float_feasign_.end(), input, input + size);
  }
  const std::vector<float>& GetFloatData() const { return float_feasign_; }
  std::vector<float>& MutableFloatData() { return float_feasign_; }
  const std::vector<uint64_t>& GetUint64Data() const { return uint64_feasign_; }
  std::vector<uint64_t>& MutableUint64Data() { return uint64_feasign_; }
  const std::string& GetType() const { return type_; }
  size_t GetBatchSize() { return offset_.size() - 1; }
  std::string& MutableType() { return type_; }

  std::string DebugString() {
    std::stringstream ss;
    ss << "\ntype: " << type_ << "\n";
    ss << "offset: ";
    ss << "[";
    for (const size_t& i : offset_) {
      ss << i << ",";
    }
    ss << "]\ndata: [";
    if (type_[0] == 'f') {
      for (const float& i : float_feasign_) {
        ss << i << ",";
      }
    } else {
      for (const uint64_t& i : uint64_feasign_) {
        ss << i << ",";
      }
    }
    ss << "]\n";
    return ss.str();
  }

 private:
  void CheckType(const std::string& type) const {
    PADDLE_ENFORCE((type == "uint64") || (type == "float"),
                   "There is no this type<%s>.", type);
  }
  void CheckFloat() const {
    PADDLE_ENFORCE(type_[0] == 'f', "Add %s value to float slot.", type_);
  }
  void CheckUint64() const {
    PADDLE_ENFORCE(type_[0] == 'u', "Add %s value to uint64 slot.", type_);
  }
  std::vector<float> float_feasign_;
  std::vector<uint64_t> uint64_feasign_;
  std::string type_;
  std::vector<size_t> offset_;
};

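// A small usage sketch (illustrative only): build one uint64 slot holding two
// instances of sizes 2 and 1; offset_ then mirrors a LoD of {0, 2, 3}:
//   MultiSlotType slot;
//   slot.Init("uint64");
//   slot.InitOffset();
//   uint64_t a[] = {11, 22};
//   slot.AppendValues(a, 2);
//   uint64_t b[] = {33};
//   slot.AppendValues(b, 1);
//   // slot.GetOffset() == {0, 2, 3}; slot.GetBatchSize() == 2
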
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const MultiSlotType& ins) {
  ar << ins.GetType();
#ifdef _LINUX
  ar << ins.GetOffset();
#else
  const auto& offset = ins.GetOffset();
  ar << (uint64_t)offset.size();
  for (const size_t& x : offset) {
    ar << (const uint64_t)x;
  }
#endif
  ar << ins.GetFloatData();
  ar << ins.GetUint64Data();
  return ar;
}

template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           MultiSlotType& ins) {
  ar >> ins.MutableType();
#ifdef _LINUX
  ar >> ins.MutableOffset();
#else
  auto& offset = ins.MutableOffset();
  offset.resize(ar.template Get<uint64_t>());
  for (size_t& x : offset) {
    uint64_t t;
    ar >> t;
    x = (size_t)t;
  }
#endif
  ar >> ins.MutableFloatData();
  ar >> ins.MutableUint64Data();
  return ar;
}

struct RecordCandidate {
  std::string ins_id_;
  std::unordered_multimap<uint16_t, FeatureKey> feas_;
  size_t shadow_index_ = -1;  // Optimization for Reservoir Sample

  RecordCandidate() {}
  RecordCandidate(const Record& rec,
                  const std::unordered_set<uint16_t>& slot_index_to_replace) {
    for (const auto& fea : rec.uint64_feasigns_) {
      if (slot_index_to_replace.find(fea.slot()) !=
          slot_index_to_replace.end()) {
        feas_.insert({fea.slot(), fea.sign()});
      }
    }
  }

  RecordCandidate& operator=(const Record& rec) {
    feas_.clear();
    ins_id_ = rec.ins_id_;
    for (auto& fea : rec.uint64_feasigns_) {
      feas_.insert({fea.slot(), fea.sign()});
    }
    return *this;
  }
};

class RecordCandidateList {
 public:
  RecordCandidateList() = default;
  RecordCandidateList(const RecordCandidateList&) {}

  size_t Size() { return cur_size_; }
  void ReSize(size_t length);

  void ReInit();
  void ReInitPass() {
    for (size_t i = 0; i < cur_size_; ++i) {
      if (candidate_list_[i].shadow_index_ != i) {
        candidate_list_[i].ins_id_ =
            candidate_list_[candidate_list_[i].shadow_index_].ins_id_;
        candidate_list_[i].feas_.swap(
            candidate_list_[candidate_list_[i].shadow_index_].feas_);
        candidate_list_[i].shadow_index_ = i;
      }
    }
    candidate_list_.resize(cur_size_);
  }

  void AddAndGet(const Record& record, RecordCandidate* result);
  void AddAndGet(const Record& record, size_t& index_result) {  // NOLINT
    // std::unique_lock<std::mutex> lock(mutex_);
    size_t index = 0;
    ++total_size_;
    auto fleet_ptr = FleetWrapper::GetInstance();
    if (!full_) {
      candidate_list_.emplace_back(record, slot_index_to_replace_);
      candidate_list_.back().shadow_index_ = cur_size_;
      ++cur_size_;
      full_ = (cur_size_ == capacity_);
    } else {
      index = fleet_ptr->LocalRandomEngine()() % total_size_;
      if (index < capacity_) {
        candidate_list_.emplace_back(record, slot_index_to_replace_);
        candidate_list_[index].shadow_index_ = candidate_list_.size() - 1;
      }
    }
    index = fleet_ptr->LocalRandomEngine()() % cur_size_;
    index_result = candidate_list_[index].shadow_index_;
  }
  const RecordCandidate& Get(size_t index) const {
    PADDLE_ENFORCE_LT(
        index, candidate_list_.size(),
        platform::errors::OutOfRange("Your index [%lu] exceeds the number of "
                                     "elements in candidate_list[%lu].",
                                     index, candidate_list_.size()));
    return candidate_list_[index];
  }
  void SetSlotIndexToReplace(
      const std::unordered_set<uint16_t>& slot_index_to_replace) {
    slot_index_to_replace_ = slot_index_to_replace;
  }

 private:
  size_t capacity_ = 0;
  std::mutex mutex_;
  bool full_ = false;
  size_t cur_size_ = 0;
  size_t total_size_ = 0;
  std::vector<RecordCandidate> candidate_list_;
  std::unordered_set<uint16_t> slot_index_to_replace_;
};

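// A rough usage sketch of the reservoir sampling above (illustrative only;
// "records" stands for any container of parsed Record instances):
//   RecordCandidateList list;
//   list.ReSize(1000);                   // reservoir capacity
//   list.SetSlotIndexToReplace({1, 3});  // keep feasigns of these slots only
//   size_t idx = 0;
//   for (const Record& rec : records) {
//     list.AddAndGet(rec, idx);          // sample rec into the reservoir
//   }
//   list.ReInitPass();                   // resolve shadow indices in place
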
template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const FeatureKey& fk) {
  ar << fk.uint64_feasign_;
  ar << fk.float_feasign_;
  return ar;
}

template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           FeatureKey& fk) {
  ar >> fk.uint64_feasign_;
  ar >> fk.float_feasign_;
  return ar;
}

template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const FeatureItem& fi) {
  ar << fi.sign();
  ar << fi.slot();
  return ar;
}

template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           FeatureItem& fi) {
  ar >> fi.sign();
  ar >> fi.slot();
  return ar;
}

template <class AR>
paddle::framework::Archive<AR>& operator<<(paddle::framework::Archive<AR>& ar,
                                           const Record& r) {
  ar << r.uint64_feasigns_;
  ar << r.float_feasigns_;
  ar << r.ins_id_;
  return ar;
}

template <class AR>
paddle::framework::Archive<AR>& operator>>(paddle::framework::Archive<AR>& ar,
                                           Record& r) {
  ar >> r.uint64_feasigns_;
  ar >> r.float_feasigns_;
  ar >> r.ins_id_;
  return ar;
}

// This DataFeed is used to feed multi-slot type data.
// The format of multi-slot type data:
//   [n feasign_0 feasign_1 ... feasign_n]*
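// For example (an assumed layout, purely illustrative): with two slots, where
// the first holds three uint64 feasigns and the second holds one float
// feasign, a single instance line could look like:
//   3 178 211 567 1 0.5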
class MultiSlotDataFeed
    : public PrivateQueueDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotDataFeed() {}
  virtual ~MultiSlotDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual bool CheckFile(const char* filename);

 protected:
  virtual void ReadThread();
  virtual void AddInstanceToInsVec(std::vector<MultiSlotType>* vec_ins,
                                   const std::vector<MultiSlotType>& instance,
                                   int index);
  virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
  virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
  virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
};

class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
 public:
  MultiSlotInMemoryDataFeed() {}
  virtual ~MultiSlotInMemoryDataFeed() {}
  virtual void Init(const DataFeedDesc& data_feed_desc);

 protected:
  virtual bool ParseOneInstance(Record* instance);
  virtual bool ParseOneInstanceFromPipe(Record* instance);
  virtual void PutToFeedVec(const std::vector<Record>& ins_vec);
  virtual void GetMsgFromLogKey(const std::string& log_key, uint64_t* search_id,
                                uint32_t* cmatch, uint32_t* rank);
  std::vector<std::vector<float>> batch_float_feasigns_;
  std::vector<std::vector<uint64_t>> batch_uint64_feasigns_;
  std::vector<std::vector<size_t>> offset_;
  std::vector<bool> visit_;
};

class PaddleBoxDataFeed : public MultiSlotInMemoryDataFeed {
 public:
  PaddleBoxDataFeed() {}
  virtual ~PaddleBoxDataFeed() {}

 protected:
  virtual void Init(const DataFeedDesc& data_feed_desc);
  virtual bool Start();
  virtual int Next();
  virtual void AssignFeedVar(const Scope& scope);
  virtual void PutToFeedVec(const std::vector<PvInstance>& pv_vec);
  virtual void PutToFeedVec(const std::vector<Record*>& ins_vec);
  virtual int GetCurrentPhase();
  virtual void GetRankOffset(const std::vector<PvInstance>& pv_vec,
                             int ins_number);
  std::string rank_offset_name_;
  int pv_batch_size_;
};

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
template <typename T>
class PrivateInstantDataFeed : public DataFeed {
 public:
  PrivateInstantDataFeed() {}
  virtual ~PrivateInstantDataFeed() {}
  void Init(const DataFeedDesc& data_feed_desc) override;
  bool Start() override { return true; }
  int Next() override;

 protected:
  // The batched data buffer.
  std::vector<MultiSlotType> ins_vec_;

  // This function preprocesses the given file, e.g. opens or mmaps it.
  virtual bool Preprocess(const std::string& filename) = 0;

  // This function releases system resources, e.g. closes the file.
  // NOTICE: it must be safe to call before Preprocess.
  virtual bool Postprocess() = 0;

  // The reading and parsing method.
  virtual bool ParseOneMiniBatch() = 0;

  // This function copies ins_vec_ into feed_vec_.
  virtual void PutToFeedVec();
};

class MultiSlotFileInstantDataFeed
    : public PrivateInstantDataFeed<std::vector<MultiSlotType>> {
 public:
  MultiSlotFileInstantDataFeed() {}
  virtual ~MultiSlotFileInstantDataFeed() {}

 protected:
  int fd_{-1};
  char* buffer_{nullptr};
  size_t end_{0};
  size_t offset_{0};

  bool Preprocess(const std::string& filename) override;

  bool Postprocess() override;

  bool ParseOneMiniBatch() override;
};
#endif

}  // namespace framework
}  // namespace paddle