|
|
|
@ -70,7 +70,7 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
|
|
|
|
|
|
|
|
|
|
void DownpourWorker::CollectLabelInfo(size_t table_idx) {
|
|
|
|
|
uint64_t table_id = static_cast<uint64_t>(
|
|
|
|
|
param_.program_config(0).pull_sparse_table_id(table_idx));
|
|
|
|
|
param_.program_config(0).pull_sparse_table_id(table_idx));
|
|
|
|
|
|
|
|
|
|
TableParameter table;
|
|
|
|
|
for (auto i : param_.sparse_table()) {
|
|
|
|
@ -82,16 +82,23 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
|
|
|
|
|
auto& feature = features_[table_id];
|
|
|
|
|
auto& feature_label = feature_labels_[table_id];
|
|
|
|
|
feature_label.resize(feature.size());
|
|
|
|
|
VLOG(3) << "going to get label_var_name " << label_var_name_[table_id];
|
|
|
|
|
Variable* var = thread_scope_->FindVar(label_var_name_[table_id]);
|
|
|
|
|
VLOG(3) << "going to get tensor";
|
|
|
|
|
LoDTensor* tensor = var->GetMutable<LoDTensor>();
|
|
|
|
|
VLOG(3) << "going to get ptr";
|
|
|
|
|
int64_t* label_ptr = tensor->data<int64_t>();
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "lele";
|
|
|
|
|
int global_index = 0;
|
|
|
|
|
for (size_t i = 0; i < sparse_key_names_[table_id].size(); ++i) {
|
|
|
|
|
VLOG(3) << "sparse_key_names_[" << i
|
|
|
|
|
<< "]: " << sparse_key_names_[table_id][i];
|
|
|
|
|
Variable* fea_var = thread_scope_->FindVar(sparse_key_names_[table_id][i]);
|
|
|
|
|
LoDTensor* tensor = fea_var->GetMutable<LoDTensor>();
|
|
|
|
|
int64_t* ids = tensor->data<int64_t>();
|
|
|
|
|
int fea_idx = 0;
|
|
|
|
|
VLOG(3) << "Haha";
|
|
|
|
|
// tensor->lod()[0].size() == batch_size + 1
|
|
|
|
|
for (auto lod_idx = 1u; lod_idx < tensor->lod()[0].size(); ++lod_idx) {
|
|
|
|
|
for (; fea_idx < tensor->lod()[0][lod_idx]; ++fea_idx) {
|
|
|
|
@ -103,6 +110,7 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
|
|
|
|
|
static_cast<float>(label_ptr[lod_idx - 1]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
VLOG(3) << "EE";
|
|
|
|
|
}
|
|
|
|
|
CHECK(global_index == feature.size())
|
|
|
|
|
<< "expect fea info size:" << feature.size() << " real:" << global_index;
|
|
|
|
@ -110,7 +118,7 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
|
|
|
|
|
|
|
|
|
|
void DownpourWorker::FillSparseValue(size_t table_idx) {
|
|
|
|
|
uint64_t table_id = static_cast<uint64_t>(
|
|
|
|
|
param_.program_config(0).pull_sparse_table_id(table_idx));
|
|
|
|
|
param_.program_config(0).pull_sparse_table_id(table_idx));
|
|
|
|
|
|
|
|
|
|
TableParameter table;
|
|
|
|
|
for (auto i : param_.sparse_table()) {
|
|
|
|
@ -152,6 +160,11 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DownpourWorker::TrainFilesWithProfiler() {
|
|
|
|
|
VLOG(3) << "Begin to train files with profiler";
|
|
|
|
|
platform::SetNumThreads(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DownpourWorker::TrainFiles() {
|
|
|
|
|
VLOG(3) << "Begin to train files";
|
|
|
|
|
platform::SetNumThreads(1);
|
|
|
|
|