|
|
|
@ -417,48 +417,46 @@ void AsyncExecutorThreadWorker::PrepareParams() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncExecutorThreadWorker::UpdateParams() {
|
|
|
|
|
for (auto i: _param_config->sparse_table_id) {//TODO
|
|
|
|
|
//for (int i = 0; i < 1; ++i) {
|
|
|
|
|
PushSparse(i);
|
|
|
|
|
}
|
|
|
|
|
//for (auto i = 0u; i < GlobalConfig::instance().dense_table_id.size(); ++i) {//TODO
|
|
|
|
|
for (auto i: _param_config->dense_table_id) {
|
|
|
|
|
PushDense(i);
|
|
|
|
|
}
|
|
|
|
|
int32_t tmp_push_dense_wait_times = -1;//_param_config->tmp_push_dense_wait_times; //TODO
|
|
|
|
|
int32_t tmp_push_sparse_wait_times = -1;//_param_config->tmp_push_sparse_wait_times; //TODO
|
|
|
|
|
static uint32_t push_dense_wait_times = static_cast<uint32_t>(tmp_push_dense_wait_times);
|
|
|
|
|
static uint32_t push_sparse_wait_times = static_cast<uint32_t>(tmp_push_sparse_wait_times);
|
|
|
|
|
|
|
|
|
|
if (_push_dense_status.size() >= push_dense_wait_times) {
|
|
|
|
|
for (auto& t : _push_dense_status) {
|
|
|
|
|
t.wait();
|
|
|
|
|
}
|
|
|
|
|
_push_dense_status.resize(0);
|
|
|
|
|
}
|
|
|
|
|
if (tmp_push_dense_wait_times == -1) {
|
|
|
|
|
_push_dense_status.resize(0);
|
|
|
|
|
}
|
|
|
|
|
if (_push_sparse_status.size() >= push_sparse_wait_times) {
|
|
|
|
|
for (auto& t : _push_sparse_status) {
|
|
|
|
|
t.wait();
|
|
|
|
|
}
|
|
|
|
|
_push_sparse_status.resize(0);
|
|
|
|
|
}
|
|
|
|
|
if (tmp_push_sparse_wait_times == -1) {
|
|
|
|
|
_push_sparse_status.resize(0);
|
|
|
|
|
}
|
|
|
|
|
//for (auto dense_table_id : GlobalConfig::instance().dense_table_id) {//TODO
|
|
|
|
|
for (auto dense_table_id: _param_config->dense_table_id) {
|
|
|
|
|
_pull_dense_thread->increase_thread_version(thread_id_, dense_table_id);
|
|
|
|
|
for (auto i : _param_config->sparse_table_id) {
|
|
|
|
|
PushSparse(i);
|
|
|
|
|
}
|
|
|
|
|
for (auto i : _param_config->dense_table_id) {
|
|
|
|
|
PushDense(i);
|
|
|
|
|
}
|
|
|
|
|
// _param_config->tmp_push_dense_wait_times
|
|
|
|
|
int32_t tmp_push_dense_wait_times = -1;
|
|
|
|
|
// _param_config->tmp_push_sparse_wait_times
|
|
|
|
|
int32_t tmp_push_sparse_wait_times = -1;
|
|
|
|
|
static uint32_t push_dense_wait_times =
|
|
|
|
|
static_cast<uint32_t>(tmp_push_dense_wait_times);
|
|
|
|
|
static uint32_t push_sparse_wait_times =
|
|
|
|
|
static_cast<uint32_t>(tmp_push_sparse_wait_times);
|
|
|
|
|
|
|
|
|
|
if (_push_dense_status.size() >= push_dense_wait_times) {
|
|
|
|
|
for (auto& t : _push_dense_status) {
|
|
|
|
|
t.wait();
|
|
|
|
|
}
|
|
|
|
|
_push_dense_status.resize(0);
|
|
|
|
|
}
|
|
|
|
|
if (tmp_push_dense_wait_times == -1) {
|
|
|
|
|
_push_dense_status.resize(0);
|
|
|
|
|
}
|
|
|
|
|
if (_push_sparse_status.size() >= push_sparse_wait_times) {
|
|
|
|
|
for (auto& t : _push_sparse_status) {
|
|
|
|
|
t.wait();
|
|
|
|
|
}
|
|
|
|
|
//}
|
|
|
|
|
_push_sparse_status.resize(0);
|
|
|
|
|
}
|
|
|
|
|
if (tmp_push_sparse_wait_times == -1) {
|
|
|
|
|
_push_sparse_status.resize(0);
|
|
|
|
|
}
|
|
|
|
|
for (auto dense_table_id : _param_config->dense_table_id) {
|
|
|
|
|
_pull_dense_thread->increase_thread_version(thread_id_, dense_table_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncExecutorThreadWorker::PushDense(int table_id) {
|
|
|
|
|
std::vector<paddle::ps::Region> regions;
|
|
|
|
|
//auto& variables = GlobalConfig::instance().dense_gradient_variable_name[table_id];
|
|
|
|
|
//std::vector<std::string> variables;
|
|
|
|
|
for (auto& t : _param_config->dense_gradient_variable_name[table_id]) {
|
|
|
|
|
Variable* var = thread_scope_->FindVar(t);
|
|
|
|
|
CHECK(var != nullptr) << "var[" << t << "] not found";
|
|
|
|
@ -469,7 +467,8 @@ void AsyncExecutorThreadWorker::PushDense(int table_id) {
|
|
|
|
|
regions.emplace_back(std::move(reg));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto status = _pslib_ptr->_worker_ptr->push_dense(regions.data(), regions.size(), table_id);
|
|
|
|
|
auto status = _pslib_ptr->_worker_ptr->push_dense(
|
|
|
|
|
regions.data(), regions.size(), table_id);
|
|
|
|
|
_push_dense_status.push_back(std::move(status));
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
@ -478,7 +477,7 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) {
|
|
|
|
|
|
|
|
|
|
auto& features = _features[table_id];
|
|
|
|
|
auto& feature_value = _feature_value[table_id];
|
|
|
|
|
auto fea_dim = _param_config->fea_dim; //TODO
|
|
|
|
|
auto fea_dim = _param_config->fea_dim;
|
|
|
|
|
// slot id starts from 1
|
|
|
|
|
features.clear();
|
|
|
|
|
features.resize(0);
|
|
|
|
|