|
|
|
@ -20,31 +20,25 @@ namespace framework {
|
|
|
|
|
std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL;
|
|
|
|
|
|
|
|
|
|
void PullDenseWorker::Initialize(const TrainerDesc& param) {
|
|
|
|
|
LOG(WARNING) << "going to initialize pull dense worker";
|
|
|
|
|
running_ = false;
|
|
|
|
|
param_ = param.pull_dense_param();
|
|
|
|
|
threshold_ = param_.threshold();
|
|
|
|
|
thread_num_ = param_.device_num();
|
|
|
|
|
sleep_time_ms_ = param_.sleep_time_ms();
|
|
|
|
|
LOG(WARNING) << "dense table size: " << param_.dense_table_size();
|
|
|
|
|
LOG(WARNING) << "thread num: " << thread_num_;
|
|
|
|
|
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
|
|
|
|
|
// setup dense variables for each table
|
|
|
|
|
int var_num = param_.dense_table(i).dense_value_name_size();
|
|
|
|
|
LOG(WARNING) << "var num: " << var_num;
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
|
|
|
|
|
dense_value_names_[tid].resize(var_num);
|
|
|
|
|
for (int j = 0; j < var_num; ++j) {
|
|
|
|
|
dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j);
|
|
|
|
|
LOG(WARNING) << "dense value names " << j << " "
|
|
|
|
|
<< dense_value_names_[tid][j];
|
|
|
|
|
}
|
|
|
|
|
// setup training version for each table
|
|
|
|
|
training_versions_[tid].resize(thread_num_, 0);
|
|
|
|
|
last_versions_[tid] = 0;
|
|
|
|
|
current_version_[tid] = 0;
|
|
|
|
|
}
|
|
|
|
|
LOG(WARNING) << "initialize pull dense worker done.";
|
|
|
|
|
fleet_ptr_ = FleetWrapper::GetInstance();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
|
|
|
|
@ -98,10 +92,7 @@ void PullDenseWorker::Run() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) {
|
|
|
|
|
LOG(WARNING) << "increase thread version input: " << thread_id << " table id "
|
|
|
|
|
<< table_id;
|
|
|
|
|
std::lock_guard<std::mutex> lock(mutex_for_version_);
|
|
|
|
|
LOG(WARNING) << "going to increase";
|
|
|
|
|
training_versions_[table_id][thread_id]++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|