|
|
|
@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) {
|
|
|
|
|
bool DataFeed::PickOneFile(std::string* filename) {
|
|
|
|
|
std::unique_lock<std::mutex> lock(mutex_for_pick_file_);
|
|
|
|
|
if (file_idx_ == filelist_.size()) {
|
|
|
|
|
VLOG(3) << "DataFeed::PickOneFile no more file to pick";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "file_idx_=" << file_idx_;
|
|
|
|
|
*filename = filelist_[file_idx_++];
|
|
|
|
|
// LOG(ERROR) << "pick file:" << *filename;
|
|
|
|
|
return true;
|
|
|
|
@ -146,17 +148,18 @@ template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
|
|
|
|
|
template <typename T>
|
|
|
|
|
InMemoryDataFeed<T>::InMemoryDataFeed() {
|
|
|
|
|
cur_channel_ = 0;
|
|
|
|
|
shuffled_ins_ = nullptr;
|
|
|
|
|
shuffled_ins_out_ = nullptr;
|
|
|
|
|
shuffled_ins_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
|
|
|
|
|
shuffled_ins_out_ = std::make_shared<paddle::framework::BlockingQueue<T>>();
|
|
|
|
|
fleet_send_batch_size_ = 10000;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
bool InMemoryDataFeed<T>::Start() {
|
|
|
|
|
DataFeed::CheckSetFileList();
|
|
|
|
|
if (memory_data_.size() != 0) {
|
|
|
|
|
CHECK_EQ(cur_channel_, 0);
|
|
|
|
|
shuffled_ins_->Extend(std::move(memory_data_));
|
|
|
|
|
std::vector<T>().swap(memory_data_);
|
|
|
|
|
if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
|
|
|
|
|
FillMemoryDataToChannel();
|
|
|
|
|
//std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
|
|
|
|
|
//std::vector<T>().swap(memory_data_);
|
|
|
|
|
}
|
|
|
|
|
DataFeed::finish_start_ = true;
|
|
|
|
|
return true;
|
|
|
|
@ -196,6 +199,31 @@ int InMemoryDataFeed<T>::Next() {
|
|
|
|
|
return DataFeed::batch_size_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::SetMemoryData(void* memory_data) {
|
|
|
|
|
memory_data_ = static_cast<std::vector<T>*>(memory_data);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::SetMemoryDataMutex(std::mutex* mutex) {
|
|
|
|
|
mutex_for_update_memory_data_ = mutex;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::SetThreadId(int thread_id) {
|
|
|
|
|
thread_id_ = thread_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
|
|
|
|
|
thread_num_ = thread_num;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
|
|
|
|
|
trainer_num_ = trainer_num;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
|
|
|
|
|
T ins;
|
|
|
|
@ -203,11 +231,54 @@ void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
|
|
|
|
|
shuffled_ins_->Push(std::move(ins));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
|
|
|
|
|
VLOG(3) << "InMemoryDataFeed<T>::FillMemoryDataToChannel, thread_id=" << thread_id_;
|
|
|
|
|
int64_t start = 0;
|
|
|
|
|
int64_t end = 0;
|
|
|
|
|
int64_t size = memory_data_->size();
|
|
|
|
|
VLOG(3) << "memory_data size=" << size;
|
|
|
|
|
for (int64_t i = 0; i <= static_cast<int64_t>(thread_id_); ++i) {
|
|
|
|
|
int64_t len = size / static_cast<int64_t>(thread_num_) +
|
|
|
|
|
(i < (size % static_cast<int64_t>(thread_num_)));
|
|
|
|
|
start = end;
|
|
|
|
|
end += len;
|
|
|
|
|
}
|
|
|
|
|
for (int64_t i = start; i < end; ++i) {
|
|
|
|
|
T& t = (*memory_data_)[i];
|
|
|
|
|
shuffled_ins_->Push(std::move(t));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::FillChannelToMemoryData() {
|
|
|
|
|
VLOG(3) << "InMemoryDataFeed<T>::FillChannelToMemoryData, thread_id=" << thread_id_;
|
|
|
|
|
std::vector<T> local_vec;
|
|
|
|
|
std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr;
|
|
|
|
|
if (cur_channel_ == 0) {
|
|
|
|
|
channel = shuffled_ins_;
|
|
|
|
|
} else {
|
|
|
|
|
channel = shuffled_ins_out_;
|
|
|
|
|
}
|
|
|
|
|
CHECK(channel != nullptr);
|
|
|
|
|
local_vec.reserve(channel->Size());
|
|
|
|
|
for (int64_t i = 0; i < channel->Size(); ++i) {
|
|
|
|
|
channel->Pop(local_vec[i]);
|
|
|
|
|
}
|
|
|
|
|
std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
|
|
|
|
|
lock.lock();
|
|
|
|
|
memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
|
|
|
|
|
lock.unlock();
|
|
|
|
|
std::vector<T>().swap(local_vec);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::LoadIntoMemory() {
|
|
|
|
|
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() begin, thread_id=" << thread_id_;
|
|
|
|
|
std::vector<T> local_vec;
|
|
|
|
|
std::string filename;
|
|
|
|
|
while (DataFeed::PickOneFile(&filename)) {
|
|
|
|
|
VLOG(3) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_;
|
|
|
|
|
int err_no = 0;
|
|
|
|
|
PrivateQueueDataFeed<T>::fp_ =
|
|
|
|
|
fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
|
|
|
|
@ -216,35 +287,50 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
|
|
|
|
|
while (ParseOneInstanceFromPipe(&instance)) {
|
|
|
|
|
local_vec.push_back(instance);
|
|
|
|
|
}
|
|
|
|
|
memory_data_.insert(memory_data_.end(), local_vec.begin(), local_vec.end());
|
|
|
|
|
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() read all lines, thread_id=" << thread_id_;
|
|
|
|
|
{
|
|
|
|
|
std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
|
|
|
|
|
memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
|
|
|
|
|
}
|
|
|
|
|
std::vector<T>().swap(local_vec);
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() end, thread_id=" << thread_id_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::LocalShuffle() {
|
|
|
|
|
std::random_shuffle(memory_data_.begin(), memory_data_.end());
|
|
|
|
|
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() begin, thread_id=" << thread_id_;
|
|
|
|
|
FillMemoryDataToChannel();
|
|
|
|
|
VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() end, thread_id=" << thread_id_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// todo global shuffle
|
|
|
|
|
/*
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
|
|
|
|
|
std::random_shuffle(memory_data_.begin(), memory_data_.end());
|
|
|
|
|
for (int64_t i = 0; i < memory_data_.size(); ++i) {
|
|
|
|
|
void InMemoryDataFeed<T>::GlobalShuffle() {
|
|
|
|
|
auto fleet_ptr = FleetWrapper::GetInstance();
|
|
|
|
|
std::vector<std::string> send_str_vec(trainer_num_);
|
|
|
|
|
for (int64_t i = 0; i < memory_data_->size(); ++i) {
|
|
|
|
|
// todo get ins id
|
|
|
|
|
//std::string ins_id = memory_data_[i].ins_id;
|
|
|
|
|
// todo hash
|
|
|
|
|
int64_t hash_id = paddle::ps::local_random_engine()();
|
|
|
|
|
//int64_t hash_id = hash(ins_id);
|
|
|
|
|
//int64_t hash_id = paddle::ps::local_random_engine()();
|
|
|
|
|
int64_t hash_id = 0;
|
|
|
|
|
int64_t node_id = hash_id % trainer_num_;
|
|
|
|
|
std::string str;
|
|
|
|
|
SerializeIns(memory_data_[i], str);
|
|
|
|
|
auto fleet_ptr = FleetWrapper::GetInstance();
|
|
|
|
|
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
|
|
|
|
|
SerializeIns((*memory_data_)[i], str);
|
|
|
|
|
send_str_vec[node_id] += str;
|
|
|
|
|
if (i % fleet_send_batch_size_ == 0 && i != 0) {
|
|
|
|
|
for (int j = 0; j < send_str_vec.size(); ++j) {
|
|
|
|
|
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]);
|
|
|
|
|
send_str_vec[j] = "";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (int j = 0; j < send_str_vec.size(); ++j) {
|
|
|
|
|
if (send_str_vec[j].length() != 0) {
|
|
|
|
|
fleet_ptr->send_client2client_msg(0, j, send_str_vec[j]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
// explicit instantiation
|
|
|
|
|
template class InMemoryDataFeed<std::vector<MultiSlotType>>;
|
|
|
|
@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
|
|
|
|
|
if (getline(file_, line)) {
|
|
|
|
|
int use_slots_num = use_slots_.size();
|
|
|
|
|
instance->resize(use_slots_num);
|
|
|
|
|
VLOG(3) << line;
|
|
|
|
|
// parse line
|
|
|
|
|
const char* str = line.c_str();
|
|
|
|
|
char* endptr = const_cast<char*>(str);
|
|
|
|
@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
|
|
|
|
|
// todo serialize ins in global shuffle
|
|
|
|
|
void MultiSlotInMemoryDataFeed::SerializeIns(
|
|
|
|
|
const std::vector<MultiSlotType>& ins, std::string& str) {
|
|
|
|
|
return;
|
|
|
|
|
auto fleet_ptr = FleetWrapper::GetInstance();
|
|
|
|
|
fleet_ptr->Serialize(ins, str);
|
|
|
|
|
}
|
|
|
|
|
// todo deserialize ins in global shuffle
|
|
|
|
|
void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins,
|
|
|
|
|
const std::string& str) {
|
|
|
|
|
return;
|
|
|
|
|
auto fleet_ptr = FleetWrapper::GetInstance();
|
|
|
|
|
fleet_ptr->Deserialize(ins, str);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|