|
|
|
@ -28,8 +28,8 @@ namespace framework {
|
|
|
|
|
|
|
|
|
|
class Dataset {
|
|
|
|
|
public:
|
|
|
|
|
Dataset() {};
|
|
|
|
|
virtual ~Dataset() {};
|
|
|
|
|
Dataset() {}
|
|
|
|
|
virtual ~Dataset() {}
|
|
|
|
|
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
|
|
|
|
|
virtual void SetThreadNum(int thread_num) = 0;
|
|
|
|
|
virtual void SetTrainerNum(int trainer_num) = 0;
|
|
|
|
@ -39,18 +39,19 @@ class Dataset {
|
|
|
|
|
virtual int GetTrainerNum() = 0;
|
|
|
|
|
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
|
|
|
|
|
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
|
|
|
|
GetReaders() = 0;
|
|
|
|
|
GetReaders() = 0;
|
|
|
|
|
virtual void LoadIntoMemory() = 0;
|
|
|
|
|
virtual void LocalShuffle() = 0;
|
|
|
|
|
virtual void GlobalShuffle() = 0;
|
|
|
|
|
virtual void CreateReaders() = 0;
|
|
|
|
|
virtual void DestroyReaders() = 0;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
virtual int ReceiveFromClient(int msg_type, int client_id,
|
|
|
|
|
const std::string& msg) = 0;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
|
template <typename T>
|
|
|
|
|
class DatasetImpl : public Dataset {
|
|
|
|
|
public:
|
|
|
|
|
DatasetImpl();
|
|
|
|
@ -69,7 +70,7 @@ class DatasetImpl : public Dataset {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
|
|
|
|
GetReaders();
|
|
|
|
|
GetReaders();
|
|
|
|
|
virtual void LoadIntoMemory();
|
|
|
|
|
virtual void LocalShuffle();
|
|
|
|
|
virtual void GlobalShuffle();
|
|
|
|
@ -82,8 +83,10 @@ class DatasetImpl : public Dataset {
|
|
|
|
|
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
|
|
|
|
|
std::vector<T> memory_data_;
|
|
|
|
|
std::mutex mutex_for_update_memory_data_;
|
|
|
|
|
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_vec_;
|
|
|
|
|
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>> shuffled_ins_out_vec_;
|
|
|
|
|
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>>
|
|
|
|
|
shuffled_ins_vec_;
|
|
|
|
|
std::vector<std::shared_ptr<paddle::framework::BlockingQueue<T>>>
|
|
|
|
|
shuffled_ins_out_vec_;
|
|
|
|
|
int thread_num_;
|
|
|
|
|
paddle::framework::DataFeedDesc data_feed_desc_;
|
|
|
|
|
std::vector<std::string> filelist_;
|
|
|
|
@ -96,6 +99,5 @@ class MultiSlotDataset : public DatasetImpl<std::vector<MultiSlotType>> {
|
|
|
|
|
virtual ~MultiSlotDataset() {}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} // end namespace framework
|
|
|
|
|
} // end namespace paddle
|
|
|
|
|