|
|
@ -14,12 +14,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <ThreadPool.h>
|
|
|
|
#include <fstream>
|
|
|
|
#include <fstream>
|
|
|
|
#include <memory>
|
|
|
|
#include <memory>
|
|
|
|
#include <mutex> // NOLINT
|
|
|
|
#include <mutex> // NOLINT
|
|
|
|
#include <set>
|
|
|
|
#include <set>
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
#include <thread> // NOLINT
|
|
|
|
|
|
|
|
#include <unordered_set>
|
|
|
|
#include <utility>
|
|
|
|
#include <utility>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
|
@ -63,6 +65,7 @@ class Dataset {
|
|
|
|
virtual void SetParseContent(bool parse_content) = 0;
|
|
|
|
virtual void SetParseContent(bool parse_content) = 0;
|
|
|
|
// set merge by ins id
|
|
|
|
// set merge by ins id
|
|
|
|
virtual void SetMergeByInsId(int merge_size) = 0;
|
|
|
|
virtual void SetMergeByInsId(int merge_size) = 0;
|
|
|
|
|
|
|
|
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns) = 0;
|
|
|
|
// set fea eval mode
|
|
|
|
// set fea eval mode
|
|
|
|
virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
|
|
|
|
virtual void SetFeaEval(bool fea_eval, int record_candidate_size) = 0;
|
|
|
|
// get file list
|
|
|
|
// get file list
|
|
|
@ -112,6 +115,11 @@ class Dataset {
|
|
|
|
virtual int64_t GetShuffleDataSize() = 0;
|
|
|
|
virtual int64_t GetShuffleDataSize() = 0;
|
|
|
|
// merge by ins id
|
|
|
|
// merge by ins id
|
|
|
|
virtual void MergeByInsId() = 0;
|
|
|
|
virtual void MergeByInsId() = 0;
|
|
|
|
|
|
|
|
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
|
|
|
|
|
|
|
|
int read_thread_num,
|
|
|
|
|
|
|
|
int consume_thread_num,
|
|
|
|
|
|
|
|
int shard_num) = 0;
|
|
|
|
|
|
|
|
virtual void ClearLocalTables() = 0;
|
|
|
|
// create preload readers
|
|
|
|
// create preload readers
|
|
|
|
virtual void CreatePreLoadReaders() = 0;
|
|
|
|
virtual void CreatePreLoadReaders() = 0;
|
|
|
|
// destroy preload readers after preload done
|
|
|
|
// destroy preload readers after preload done
|
|
|
@ -148,7 +156,7 @@ class DatasetImpl : public Dataset {
|
|
|
|
virtual void SetParseInsId(bool parse_ins_id);
|
|
|
|
virtual void SetParseInsId(bool parse_ins_id);
|
|
|
|
virtual void SetParseContent(bool parse_content);
|
|
|
|
virtual void SetParseContent(bool parse_content);
|
|
|
|
virtual void SetMergeByInsId(int merge_size);
|
|
|
|
virtual void SetMergeByInsId(int merge_size);
|
|
|
|
|
|
|
|
virtual void SetGenerateUniqueFeasign(bool gen_uni_feasigns);
|
|
|
|
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
|
|
|
|
virtual void SetFeaEval(bool fea_eval, int record_candidate_size);
|
|
|
|
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
|
|
|
|
virtual const std::vector<std::string>& GetFileList() { return filelist_; }
|
|
|
|
virtual int GetThreadNum() { return thread_num_; }
|
|
|
|
virtual int GetThreadNum() { return thread_num_; }
|
|
|
@ -179,6 +187,11 @@ class DatasetImpl : public Dataset {
|
|
|
|
virtual int64_t GetMemoryDataSize();
|
|
|
|
virtual int64_t GetMemoryDataSize();
|
|
|
|
virtual int64_t GetShuffleDataSize();
|
|
|
|
virtual int64_t GetShuffleDataSize();
|
|
|
|
virtual void MergeByInsId() {}
|
|
|
|
virtual void MergeByInsId() {}
|
|
|
|
|
|
|
|
// Default implementation of GenerateLocalTablesUnlock: the body is empty, so
// table_id, feadim, read_thread_num, consume_thread_num and shard_num are all
// ignored and no state is touched.
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
|
|
|
|
|
|
|
|
int read_thread_num,
|
|
|
|
|
|
|
|
int consume_thread_num,
|
|
|
|
|
|
|
|
int shard_num) {}
|
|
|
|
|
|
|
|
// Default implementation of ClearLocalTables: intentionally empty, does nothing.
virtual void ClearLocalTables() {}
|
|
|
|
virtual void CreatePreLoadReaders();
|
|
|
|
virtual void CreatePreLoadReaders();
|
|
|
|
virtual void DestroyPreLoadReaders();
|
|
|
|
virtual void DestroyPreLoadReaders();
|
|
|
|
virtual void SetPreLoadThreadNum(int thread_num);
|
|
|
|
virtual void SetPreLoadThreadNum(int thread_num);
|
|
|
@ -195,6 +208,7 @@ class DatasetImpl : public Dataset {
|
|
|
|
int channel_num_;
|
|
|
|
int channel_num_;
|
|
|
|
std::vector<paddle::framework::Channel<T>> multi_output_channel_;
|
|
|
|
std::vector<paddle::framework::Channel<T>> multi_output_channel_;
|
|
|
|
std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
|
|
|
|
std::vector<paddle::framework::Channel<T>> multi_consume_channel_;
|
|
|
|
|
|
|
|
std::vector<std::unordered_set<uint64_t>> local_tables_;
|
|
|
|
// when read ins, we put ins from one channel to the other,
|
|
|
|
// when read ins, we put ins from one channel to the other,
|
|
|
|
// and when finish reading, we set cur_channel = 1 - cur_channel,
|
|
|
|
// and when finish reading, we set cur_channel = 1 - cur_channel,
|
|
|
|
// so if cur_channel=0, all data are in output_channel, else consume_channel
|
|
|
|
// so if cur_channel=0, all data are in output_channel, else consume_channel
|
|
|
@ -202,6 +216,7 @@ class DatasetImpl : public Dataset {
|
|
|
|
std::vector<T> slots_shuffle_original_data_;
|
|
|
|
std::vector<T> slots_shuffle_original_data_;
|
|
|
|
RecordCandidateList slots_shuffle_rclist_;
|
|
|
|
RecordCandidateList slots_shuffle_rclist_;
|
|
|
|
int thread_num_;
|
|
|
|
int thread_num_;
|
|
|
|
|
|
|
|
int pull_sparse_to_local_thread_num_;
|
|
|
|
paddle::framework::DataFeedDesc data_feed_desc_;
|
|
|
|
paddle::framework::DataFeedDesc data_feed_desc_;
|
|
|
|
int trainer_num_;
|
|
|
|
int trainer_num_;
|
|
|
|
std::vector<std::string> filelist_;
|
|
|
|
std::vector<std::string> filelist_;
|
|
|
@ -217,9 +232,11 @@ class DatasetImpl : public Dataset {
|
|
|
|
bool parse_content_;
|
|
|
|
bool parse_content_;
|
|
|
|
size_t merge_size_;
|
|
|
|
size_t merge_size_;
|
|
|
|
bool slots_shuffle_fea_eval_ = false;
|
|
|
|
bool slots_shuffle_fea_eval_ = false;
|
|
|
|
|
|
|
|
bool gen_uni_feasigns_ = false;
|
|
|
|
int preload_thread_num_;
|
|
|
|
int preload_thread_num_;
|
|
|
|
std::mutex global_index_mutex_;
|
|
|
|
std::mutex global_index_mutex_;
|
|
|
|
int64_t global_index_ = 0;
|
|
|
|
int64_t global_index_ = 0;
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<ThreadPool>> consume_task_pool_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
// use std::vector<MultiSlotType> or Record as data type
|
|
|
|
// use std::vector<MultiSlotType> or Record as data type
|
|
|
@ -227,6 +244,16 @@ class MultiSlotDataset : public DatasetImpl<Record> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
MultiSlotDataset() {}
|
|
|
|
MultiSlotDataset() {}
|
|
|
|
virtual void MergeByInsId();
|
|
|
|
virtual void MergeByInsId();
|
|
|
|
|
|
|
|
virtual void GenerateLocalTablesUnlock(int table_id, int feadim,
|
|
|
|
|
|
|
|
int read_thread_num,
|
|
|
|
|
|
|
|
int consume_thread_num, int shard_num);
|
|
|
|
|
|
|
|
// Discards every local table and releases the memory they occupied.
//
// Swapping local_tables_ with an empty temporary hands the old contents to
// that temporary, whose destructor runs at the end of the statement. This
// destroys each contained unordered_set (freeing its buckets and nodes) and
// the vector's own buffer, so no per-table clear() is needed first.
// A plain clear() would leave the vector's capacity allocated.
//
// After the call local_tables_ is empty.
virtual void ClearLocalTables() {
  std::vector<std::unordered_set<uint64_t>>().swap(local_tables_);
}
|
|
|
|
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
|
|
|
|
virtual void SlotsShuffle(const std::set<std::string>& slots_to_replace);
|
|
|
|
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
|
|
|
|
virtual void GetRandomData(const std::set<uint16_t>& slots_to_replace,
|
|
|
|
std::vector<Record>* result);
|
|
|
|
std::vector<Record>* result);
|
|
|
|