|  |  |  | @ -24,6 +24,7 @@ | 
			
		
	
		
			
				
					|  |  |  |  | namespace paddle { | 
			
		
	
		
			
				
					|  |  |  |  | namespace framework { | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // constructor
 | 
			
		
	
		
			
				
					|  |  |  |  | template <typename T> | 
			
		
	
		
			
				
					|  |  |  |  | DatasetImpl<T>::DatasetImpl() { | 
			
		
	
		
			
				
					|  |  |  |  |   thread_num_ = 1; | 
			
		
	
	
		
			
				
					|  |  |  | @ -31,37 +32,24 @@ DatasetImpl<T>::DatasetImpl() { | 
			
		
	
		
			
				
					|  |  |  |  |   file_idx_ = 0; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // set filelist, file_idx_ will reset to zero.
 | 
			
		
	
		
			
				
					|  |  |  |  | template <typename T> | 
			
		
	
		
			
				
					|  |  |  |  | void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) { | 
			
		
	
		
			
				
					|  |  |  |  |   VLOG(3) << "filelist size: " << filelist.size(); | 
			
		
	
		
			
				
					|  |  |  |  |   filelist_ = filelist; | 
			
		
	
		
			
				
					|  |  |  |  |   file_idx_ = 0; | 
			
		
	
		
			
				
					|  |  |  |  |   /*
 | 
			
		
	
		
			
				
					|  |  |  |  |   int file_cnt = filelist_.size(); | 
			
		
	
		
			
				
					|  |  |  |  |   if (thread_num_ > file_cnt) { | 
			
		
	
		
			
				
					|  |  |  |  |     VLOG(1) << "DataSet thread num = " << thread_num_ | 
			
		
	
		
			
				
					|  |  |  |  |             << ", file num = " << file_cnt | 
			
		
	
		
			
				
					|  |  |  |  |             << ". Changing DataSet thread num = " << file_cnt; | 
			
		
	
		
			
				
					|  |  |  |  |     thread_num_ = file_cnt; | 
			
		
	
		
			
				
					|  |  |  |  |   }*/ | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // buggy here, a user should set filelist first before this function
 | 
			
		
	
		
			
				
					|  |  |  |  | // not user friendly
 | 
			
		
	
		
			
				
					|  |  |  |  | // set expect thread num. actually it may change
 | 
			
		
	
		
			
				
					|  |  |  |  | template <typename T> | 
			
		
	
		
			
				
					|  |  |  |  | void DatasetImpl<T>::SetThreadNum(int thread_num) { | 
			
		
	
		
			
				
					|  |  |  |  |   VLOG(3) << "SetThreadNum thread_num=" << thread_num; | 
			
		
	
		
			
				
					|  |  |  |  |   //int file_cnt = filelist_.size();
 | 
			
		
	
		
			
				
					|  |  |  |  |   /*
 | 
			
		
	
		
			
				
					|  |  |  |  |   if (file_cnt != 0 && thread_num > file_cnt) { | 
			
		
	
		
			
				
					|  |  |  |  |     VLOG(3) << "DataSet thread num = " << thread_num | 
			
		
	
		
			
				
					|  |  |  |  |             << ", file num = " << file_cnt | 
			
		
	
		
			
				
					|  |  |  |  |             << ". Changing DataSet thread num = " << file_cnt; | 
			
		
	
		
			
				
					|  |  |  |  |     thread_num = file_cnt; | 
			
		
	
		
			
				
					|  |  |  |  |   }*/ | 
			
		
	
		
			
				
					|  |  |  |  |   thread_num_ = thread_num; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // if you run distributed, and want to do global shuffle,
 | 
			
		
	
		
			
				
					|  |  |  |  | // set this before global shuffle.
 | 
			
		
	
		
			
				
					|  |  |  |  | // be sure you call CreateReaders before SetTrainerNum
 | 
			
		
	
		
			
				
					|  |  |  |  | template <typename T> | 
			
		
	
		
			
				
					|  |  |  |  | void DatasetImpl<T>::SetTrainerNum(int trainer_num) { | 
			
		
	
		
			
				
					|  |  |  |  |   trainer_num_ = trainer_num; | 
			
		
	
	
		
			
				
					|  |  |  | @ -86,12 +74,16 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) { | 
			
		
	
		
			
				
					|  |  |  |  |                                                 &data_feed_desc_); | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // readers_.size() may not be equal to thread_num_,
 | 
			
		
	
		
			
				
					|  |  |  |  | // it changes when filelist_.size() < thread_num_
 | 
			
		
	
		
			
				
					|  |  |  |  | template <typename T> | 
			
		
	
		
			
				
					|  |  |  |  | std::vector<std::shared_ptr<paddle::framework::DataFeed>>& | 
			
		
	
		
			
				
					|  |  |  |  | DatasetImpl<T>::GetReaders() { | 
			
		
	
		
			
				
					|  |  |  |  |   return readers_; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // load data into memory, Dataset hold this memory,
 | 
			
		
	
		
			
				
					|  |  |  |  | // which will later be fed into readers' channel
 | 
			
		
	
		
			
				
					|  |  |  |  | template <typename T> | 
			
		
	
		
			
				
					|  |  |  |  | void DatasetImpl<T>::LoadIntoMemory() { | 
			
		
	
		
			
				
					|  |  |  |  |   VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin"; | 
			
		
	
	
		
			
				
					|  |  |  | @ -114,6 +106,7 @@ void DatasetImpl<T>::LoadIntoMemory() { | 
			
		
	
		
			
				
					|  |  |  |  |           << ", cost time=" << timeline.ElapsedSec() << " seconds"; | 
			
		
	
		
			
				
					|  |  |  |  | } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | // do local shuffle
 | 
			
		
	
		
			
				
					|  |  |  |  | template <typename T> | 
			
		
	
		
			
				
					|  |  |  |  | void DatasetImpl<T>::LocalShuffle() { | 
			
		
	
		
			
				
					|  |  |  |  |   VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin"; | 
			
		
	
	
		
			
				
					|  |  |  | 
 |