|
|
|
@ -24,6 +24,7 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
// constructor
|
|
|
|
|
template <typename T>
|
|
|
|
|
DatasetImpl<T>::DatasetImpl() {
|
|
|
|
|
thread_num_ = 1;
|
|
|
|
@ -31,37 +32,24 @@ DatasetImpl<T>::DatasetImpl() {
|
|
|
|
|
file_idx_ = 0;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// set filelist, file_idx_ will reset to zero.
|
|
|
|
|
template <typename T>
|
|
|
|
|
void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
|
|
|
|
|
VLOG(3) << "filelist size: " << filelist.size();
|
|
|
|
|
filelist_ = filelist;
|
|
|
|
|
file_idx_ = 0;
|
|
|
|
|
/*
|
|
|
|
|
int file_cnt = filelist_.size();
|
|
|
|
|
if (thread_num_ > file_cnt) {
|
|
|
|
|
VLOG(1) << "DataSet thread num = " << thread_num_
|
|
|
|
|
<< ", file num = " << file_cnt
|
|
|
|
|
<< ". Changing DataSet thread num = " << file_cnt;
|
|
|
|
|
thread_num_ = file_cnt;
|
|
|
|
|
}*/
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// buggy here, a user should set filelist first before this function
|
|
|
|
|
// not user friendly
|
|
|
|
|
// set expect thread num. actually it may change
|
|
|
|
|
template <typename T>
|
|
|
|
|
void DatasetImpl<T>::SetThreadNum(int thread_num) {
|
|
|
|
|
VLOG(3) << "SetThreadNum thread_num=" << thread_num;
|
|
|
|
|
//int file_cnt = filelist_.size();
|
|
|
|
|
/*
|
|
|
|
|
if (file_cnt != 0 && thread_num > file_cnt) {
|
|
|
|
|
VLOG(3) << "DataSet thread num = " << thread_num
|
|
|
|
|
<< ", file num = " << file_cnt
|
|
|
|
|
<< ". Changing DataSet thread num = " << file_cnt;
|
|
|
|
|
thread_num = file_cnt;
|
|
|
|
|
}*/
|
|
|
|
|
thread_num_ = thread_num;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// if you run distributed, and want to do global shuffle,
|
|
|
|
|
// set this before global shuffle.
|
|
|
|
|
// be sure you call CreateReaders before SetTrainerNum
|
|
|
|
|
template <typename T>
|
|
|
|
|
void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
|
|
|
|
|
trainer_num_ = trainer_num;
|
|
|
|
@ -86,12 +74,16 @@ void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
|
|
|
|
|
&data_feed_desc_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// readers_.size() may not be equal to thread_num_,
|
|
|
|
|
// it changes when filelist_.size() < thread_num_
|
|
|
|
|
template <typename T>
|
|
|
|
|
std::vector<std::shared_ptr<paddle::framework::DataFeed>>&
|
|
|
|
|
DatasetImpl<T>::GetReaders() {
|
|
|
|
|
return readers_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// load data into memory, Dataset hold this memory,
|
|
|
|
|
// which will later be fed into readers' channel
|
|
|
|
|
template <typename T>
|
|
|
|
|
void DatasetImpl<T>::LoadIntoMemory() {
|
|
|
|
|
VLOG(3) << "DatasetImpl<T>::LoadIntoMemory() begin";
|
|
|
|
@ -114,6 +106,7 @@ void DatasetImpl<T>::LoadIntoMemory() {
|
|
|
|
|
<< ", cost time=" << timeline.ElapsedSec() << " seconds";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// do local shuffle
|
|
|
|
|
template <typename T>
|
|
|
|
|
void DatasetImpl<T>::LocalShuffle() {
|
|
|
|
|
VLOG(3) << "DatasetImpl<T>::LocalShuffle() begin";
|
|
|
|
|