|
|
|
@ -69,10 +69,16 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
|
|
|
|
|
auto table = param_.sparse_table(table_idx);
|
|
|
|
|
uint64_t table_id =
|
|
|
|
|
static_cast<uint64_t>(param_.sparse_table(table_idx).table_id());
|
|
|
|
|
uint64_t table_id = static_cast<uint64_t>(
|
|
|
|
|
param_.program_config(0).pull_sparse_table_id(table_idx));
|
|
|
|
|
|
|
|
|
|
TableParameter table;
|
|
|
|
|
for (auto i : param_.sparse_table()) {
|
|
|
|
|
if (i.table_id() == table_id) {
|
|
|
|
|
table = i;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto& feature = features_[table_id];
|
|
|
|
|
auto& feature_label = feature_labels_[table_id];
|
|
|
|
|
feature_label.resize(feature.size());
|
|
|
|
@ -103,10 +109,17 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DownpourWorker::FillSparseValue(size_t table_idx) {
|
|
|
|
|
auto table = param_.sparse_table(table_idx);
|
|
|
|
|
uint64_t table_id = static_cast<uint64_t>(
|
|
|
|
|
param_.program_config(0).pull_sparse_table_id(table_idx));
|
|
|
|
|
|
|
|
|
|
TableParameter table;
|
|
|
|
|
for (auto i : param_.sparse_table()) {
|
|
|
|
|
if (i.table_id() == table_id) {
|
|
|
|
|
table = i;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
uint64_t table_id =
|
|
|
|
|
static_cast<uint64_t>(param_.sparse_table(table_idx).table_id());
|
|
|
|
|
auto& fea_value = feature_values_[table_id];
|
|
|
|
|
auto fea_idx = 0u;
|
|
|
|
|
|
|
|
|
@ -147,11 +160,20 @@ void DownpourWorker::TrainFiles() {
|
|
|
|
|
int cur_batch;
|
|
|
|
|
while ((cur_batch = device_reader_->Next()) > 0) {
|
|
|
|
|
// pull sparse here
|
|
|
|
|
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(param_.sparse_table(i).table_id());
|
|
|
|
|
fleet_ptr_->PullSparseVarsSync(
|
|
|
|
|
*thread_scope_, tid, sparse_key_names_[tid], &features_[tid],
|
|
|
|
|
&feature_values_[tid], param_.sparse_table(i).fea_dim());
|
|
|
|
|
for (size_t i = 0; i < param_.program_config(0).pull_sparse_table_id_size();
|
|
|
|
|
++i) {
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(
|
|
|
|
|
param_.program_config(0).pull_sparse_table_id(i));
|
|
|
|
|
TableParameter table;
|
|
|
|
|
for (auto i : param_.sparse_table()) {
|
|
|
|
|
if (i.table_id() == tid) {
|
|
|
|
|
table = i;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
fleet_ptr_->PullSparseVarsSync(*thread_scope_, tid,
|
|
|
|
|
sparse_key_names_[tid], &features_[tid],
|
|
|
|
|
&feature_values_[tid], table.fea_dim());
|
|
|
|
|
CollectLabelInfo(i);
|
|
|
|
|
FillSparseValue(i);
|
|
|
|
|
}
|
|
|
|
@ -172,17 +194,27 @@ void DownpourWorker::TrainFiles() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// push gradients here
|
|
|
|
|
for (size_t i = 0; i < param_.sparse_table_size(); ++i) {
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(param_.sparse_table(i).table_id());
|
|
|
|
|
for (size_t i = 0; i < param_.program_config(0).push_sparse_table_id_size();
|
|
|
|
|
++i) {
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(
|
|
|
|
|
param_.program_config(0).push_sparse_table_id(i));
|
|
|
|
|
TableParameter table;
|
|
|
|
|
for (auto i : param_.sparse_table()) {
|
|
|
|
|
if (i.table_id() == tid) {
|
|
|
|
|
table = i;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
fleet_ptr_->PushSparseVarsWithLabelAsync(
|
|
|
|
|
*thread_scope_, tid, features_[tid], feature_labels_[tid],
|
|
|
|
|
sparse_key_names_[tid], sparse_grad_names_[tid],
|
|
|
|
|
param_.sparse_table(i).emb_dim(), &feature_grads_[tid],
|
|
|
|
|
&push_sparse_status_);
|
|
|
|
|
sparse_key_names_[tid], sparse_grad_names_[tid], table.emb_dim(),
|
|
|
|
|
&feature_grads_[tid], &push_sparse_status_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
|
|
|
|
|
for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
|
|
|
|
|
++i) {
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(
|
|
|
|
|
param_.program_config(0).push_dense_table_id(i));
|
|
|
|
|
fleet_ptr_->PushDenseVarsAsync(
|
|
|
|
|
*thread_scope_, tid, dense_grad_names_[tid], &push_sparse_status_);
|
|
|
|
|
}
|
|
|
|
@ -219,8 +251,10 @@ void DownpourWorker::TrainFiles() {
|
|
|
|
|
push_sparse_status_.resize(0);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
|
|
|
|
|
for (size_t i = 0; i < param_.program_config(0).push_dense_table_id_size();
|
|
|
|
|
++i) {
|
|
|
|
|
uint64_t tid = static_cast<uint64_t>(
|
|
|
|
|
param_.program_config(0).push_dense_table_id(i));
|
|
|
|
|
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|