|
|
|
@ -40,22 +40,43 @@ class Dataset {
|
|
|
|
|
public:
|
|
|
|
|
Dataset() {}
|
|
|
|
|
// Virtual destructor so that a derived dataset can be destroyed through a
// pointer to Dataset.
virtual ~Dataset() = default;
|
|
|
|
|
// Sets the list of files for this dataset.
// filelist: the file names, copied or referenced at the implementation's
// discretion (not visible from this interface).
|
|
|
|
|
virtual void SetFileList(const std::vector<std::string>& filelist) = 0;
|
|
|
|
|
// Sets the thread number, i.e. the number of reader threads.
// thread_num: requested thread count.
|
|
|
|
|
virtual void SetThreadNum(int thread_num) = 0;
|
|
|
|
|
// Sets the number of trainers (workers).
// trainer_num: trainer count.
|
|
|
|
|
virtual void SetTrainerNum(int trainer_num) = 0;
|
|
|
|
|
// Sets the HDFS configuration: the file system name and the UGI.
// fs_name: file system name.
// fs_ugi:  UGI string (exact format is not visible here -- confirm with
//          the implementation).
|
|
|
|
|
virtual void SetHdfsConfig(const std::string& fs_name,
|
|
|
|
|
const std::string& fs_ugi) = 0;
|
|
|
|
|
// Sets the data feed description, which contains:
|
|
|
|
|
// the data feed name, the batch size and the slots.
// data_feed_desc_str: the description in string form (presumably a
// serialized DataFeedDesc -- confirm the expected format).
|
|
|
|
|
virtual void SetDataFeedDesc(const std::string& data_feed_desc_str) = 0;
|
|
|
|
|
// Returns the file list as a const reference.
|
|
|
|
|
virtual const std::vector<std::string>& GetFileList() = 0;
|
|
|
|
|
// Returns the thread number.
|
|
|
|
|
virtual int GetThreadNum() = 0;
|
|
|
|
|
// Returns the number of trainers (workers).
|
|
|
|
|
virtual int GetTrainerNum() = 0;
|
|
|
|
|
// Returns the data feed description as a const reference.
|
|
|
|
|
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() = 0;
|
|
|
|
|
// Returns the readers. The number of readers depends on both the thread
|
|
|
|
|
// number and the size of the file list.
// The result is a non-const reference to a vector of shared DataFeed
// pointers.
|
|
|
|
|
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
|
|
|
|
GetReaders() = 0;
|
|
|
|
|
// Registers the message handler used for messages between workers.
|
|
|
|
|
virtual void RegisterClientToClientMsgHandler() = 0;
|
|
|
|
|
// Loads all data into memory.
|
|
|
|
|
virtual void LoadIntoMemory() = 0;
|
|
|
|
|
// Releases all data held in memory.
|
|
|
|
|
virtual void ReleaseMemory() = 0;
|
|
|
|
|
// Shuffles the data locally.
|
|
|
|
|
virtual void LocalShuffle() = 0;
|
|
|
|
|
// Shuffles the data globally.
// NOTE(review): presumably this exchanges data between workers -- confirm
// against the implementation.
|
|
|
|
|
virtual void GlobalShuffle() = 0;
|
|
|
|
|
// Creates the readers.
|
|
|
|
|
virtual void CreateReaders() = 0;
|
|
|
|
|
// Destroys the readers.
|
|
|
|
|
virtual void DestroyReaders() = 0;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
@ -84,10 +105,12 @@ class DatasetImpl : public Dataset {
|
|
|
|
|
// Returns a const reference to the stored data feed description
// (the data_feed_desc_ member).
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
|
|
|
|
|
return data_feed_desc_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns the readers as a non-const reference to a vector of shared
// DataFeed pointers. Declaration only; the definition is not part of this
// excerpt.
virtual std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
|
|
|
|
GetReaders();
|
|
|
|
|
|
|
|
|
|
// Registers the message handler used for messages between workers.
// Declaration only; the definition is not part of this excerpt.
virtual void RegisterClientToClientMsgHandler();
|
|
|
|
|
// Loads all data into memory.
// Declaration only; the definition is not part of this excerpt.
virtual void LoadIntoMemory();
|
|
|
|
|
// Releases all data held in memory.
// Declaration only; the definition is not part of this excerpt.
virtual void ReleaseMemory();
|
|
|
|
|
// Shuffles the data locally.
// Declaration only; the definition is not part of this excerpt.
virtual void LocalShuffle();
|
|
|
|
|
// Shuffles the data globally.
// Declaration only; the definition is not part of this excerpt.
virtual void GlobalShuffle();
|
|
|
|
|
// Creates the readers.
// Declaration only; the definition is not part of this excerpt.
virtual void CreateReaders();
|
|
|
|
|