|
|
|
@ -237,11 +237,21 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
|
|
|
|
|
thread_num_ = thread_num;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::SetTrainerId(int trainer_id) {
|
|
|
|
|
trainer_id_ = trainer_id;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
|
|
|
|
|
trainer_num_ = trainer_num;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::SetFleetSendBatchSize(int64_t size) {
|
|
|
|
|
fleet_send_batch_size_ = size;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
|
|
|
|
|
#ifdef _LINUX
|
|
|
|
@ -361,8 +371,15 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
|
|
|
|
|
VLOG(3) << "GlobalShuffle() begin, thread_id=" << thread_id_;
|
|
|
|
|
auto fleet_ptr = FleetWrapper::GetInstance();
|
|
|
|
|
std::vector<std::vector<T*>> send_vec(trainer_num_);
|
|
|
|
|
std::vector<int> send_index(trainer_num_);
|
|
|
|
|
std::vector<T> local_send_vec;
|
|
|
|
|
uint64_t reserve_len = fleet_send_batch_size_ / trainer_num_;
|
|
|
|
|
for (auto& vec : send_vec) {
|
|
|
|
|
vec.reserve(fleet_send_batch_size_);
|
|
|
|
|
vec.reserve(reserve_len);
|
|
|
|
|
}
|
|
|
|
|
local_send_vec.reserve(reserve_len);
|
|
|
|
|
for (int i = 0; i < trainer_num_; ++i) {
|
|
|
|
|
send_index[i] = i;
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::future<int32_t>> total_status;
|
|
|
|
|
auto interval = GetMemoryDataInterval();
|
|
|
|
@ -373,9 +390,23 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
|
|
|
|
|
// std::string ins_id = memory_data_[i].ins_id;
|
|
|
|
|
int64_t random_num = rand_r(&rand_seed);
|
|
|
|
|
int64_t node_id = random_num % trainer_num_;
|
|
|
|
|
send_vec[node_id].push_back(&((*memory_data_)[i]));
|
|
|
|
|
if (node_id == trainer_id_) {
|
|
|
|
|
local_send_vec.push_back((*memory_data_)[i]);
|
|
|
|
|
} else {
|
|
|
|
|
send_vec[node_id].push_back(&((*memory_data_)[i]));
|
|
|
|
|
}
|
|
|
|
|
if (i % fleet_send_batch_size_ == 0 && i != 0) {
|
|
|
|
|
for (int j = 0; j < send_vec.size(); ++j) {
|
|
|
|
|
// shuffle the sequence of sending to avoid network timeout error
|
|
|
|
|
std::random_shuffle(send_index.begin(), send_index.end());
|
|
|
|
|
for (int index = 0; index < send_index.size(); ++index) {
|
|
|
|
|
int j = send_index[index];
|
|
|
|
|
if (j == trainer_id_) {
|
|
|
|
|
VLOG(3) << "send to local, ins num=" << local_send_vec.size()
|
|
|
|
|
<< ", node_id=" << j << ", thread_id=" << thread_id_;
|
|
|
|
|
shuffled_ins_->Extend(std::move(local_send_vec));
|
|
|
|
|
local_send_vec.clear();
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::string send_str;
|
|
|
|
|
SerializeIns(send_vec[j], &send_str);
|
|
|
|
|
VLOG(3) << "send str_length=" << send_str.length()
|
|
|
|
@ -388,8 +419,14 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (int j = 0; j < send_vec.size(); ++j) {
|
|
|
|
|
if (send_vec[j].size() != 0) {
|
|
|
|
|
// shuffle the sequence of sending to avoid network timeout error
|
|
|
|
|
std::random_shuffle(send_index.begin(), send_index.end());
|
|
|
|
|
for (int index = 0; index < send_index.size(); ++index) {
|
|
|
|
|
int j = send_index[index];
|
|
|
|
|
if (j == trainer_id_ && local_send_vec.size() != 0) {
|
|
|
|
|
shuffled_ins_->Extend(std::move(local_send_vec));
|
|
|
|
|
std::vector<T>().swap(local_send_vec);
|
|
|
|
|
} else if (send_vec[j].size() != 0) {
|
|
|
|
|
std::string send_str;
|
|
|
|
|
SerializeIns(send_vec[j], &send_str);
|
|
|
|
|
VLOG(3) << "send str_length=" << send_str.length() << " to node_id=" << j
|
|
|
|
|