|
|
|
@ -66,15 +66,20 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
|
|
|
|
|
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
|
|
|
|
|
_pslib_ptr->init_server(dist_desc, index);//TODO done
|
|
|
|
|
|
|
|
|
|
_pslib_ptr =
|
|
|
|
|
std::shared_ptr<paddle::distributed::PSlib>(
|
|
|
|
|
new paddle::distributed::PSlib());
|
|
|
|
|
_pslib_ptr->init_server(dist_desc, index);
|
|
|
|
|
InitParamConfig();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncExecutor::InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) {
|
|
|
|
|
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
|
|
|
|
|
_pslib_ptr->init_worker(dist_desc, host_sign_list.data(), node_num, index);//TODO done
|
|
|
|
|
void AsyncExecutor::InitWorker(const std::string& dist_desc,
|
|
|
|
|
const std::vector<uint64_t>& host_sign_list,
|
|
|
|
|
int node_num, int index) {
|
|
|
|
|
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
|
|
|
|
|
new paddle::distributed::PSlib());
|
|
|
|
|
_pslib_ptr->init_worker(
|
|
|
|
|
dist_desc, host_sign_list.data(), node_num, index);
|
|
|
|
|
|
|
|
|
|
InitParamConfig();
|
|
|
|
|
}
|
|
|
|
@ -87,43 +92,65 @@ void AsyncExecutor::StopServer() {
|
|
|
|
|
_pslib_ptr->stop_server();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncExecutor::GatherServers(std::vector<uint64_t>& host_sign_list, int node_num) {
|
|
|
|
|
void AsyncExecutor::GatherServers(
|
|
|
|
|
std::vector<uint64_t>& host_sign_list, int node_num) {
|
|
|
|
|
_pslib_ptr->gather_servers(host_sign_list.data(), node_num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncExecutor::InitParamConfig() {
|
|
|
|
|
for (int i = 0; i < _pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param_size(); ++i) {
|
|
|
|
|
if (_pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param(i).table_class().find("SparseTable") != -1) {
|
|
|
|
|
_param_config.fea_dim = _pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param(i).accessor().fea_dim(); //TODO
|
|
|
|
|
for (int i = 0; i <
|
|
|
|
|
_pslib_ptr->get_param()->server_param().\
|
|
|
|
|
downpour_server_param().\
|
|
|
|
|
downpour_table_param_size();
|
|
|
|
|
++i) {
|
|
|
|
|
if (_pslib_ptr->get_param()->server_param().\
|
|
|
|
|
downpour_server_param().downpour_table_param(i).\
|
|
|
|
|
table_class().find("SparseTable") != -1) {
|
|
|
|
|
_param_config.fea_dim = _pslib_ptr->get_param()->server_param().\
|
|
|
|
|
downpour_server_param().\
|
|
|
|
|
downpour_table_param(i).\
|
|
|
|
|
accessor().fea_dim();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
_param_config.slot_dim = _param_config.fea_dim - 2; //TODO
|
|
|
|
|
_param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
|
|
|
|
|
_param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
|
|
|
|
|
|
|
|
|
|
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size(); ++t) {
|
|
|
|
|
_param_config.skip_op.push_back(_pslib_ptr->get_param()->trainer_param().skip_op(t));
|
|
|
|
|
_param_config.slot_dim = _param_config.fea_dim - 2;
|
|
|
|
|
_param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
|
|
|
|
|
_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
|
|
|
|
|
_param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
|
|
|
|
|
_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
|
|
|
|
|
|
|
|
|
|
for (auto t = 0u;
|
|
|
|
|
t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
|
|
|
|
|
++t) {
|
|
|
|
|
_param_config.skip_op.push_back(
|
|
|
|
|
_pslib_ptr->get_param()->trainer_param().skip_op(t));
|
|
|
|
|
}
|
|
|
|
|
//sparse
|
|
|
|
|
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
|
|
|
|
|
|
|
|
|
|
for (auto t = 0u;
|
|
|
|
|
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size();
|
|
|
|
|
++t) {
|
|
|
|
|
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
|
|
|
|
|
std::vector<std::string> tmp_sparse_variable_name;
|
|
|
|
|
for (int i = 0u; i < table.slot_value_size(); ++i) {
|
|
|
|
|
tmp_sparse_variable_name.push_back(table.slot_value(i));
|
|
|
|
|
_param_config.slot_alias_to_table[table.slot_key(i)] = table.table_id();
|
|
|
|
|
_param_config.slot_alias_to_table[table.slot_key(i)] =
|
|
|
|
|
table.table_id();
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::string> tmp_sparse_gradient_variable_name;
|
|
|
|
|
for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
|
|
|
|
|
tmp_sparse_gradient_variable_name.push_back(
|
|
|
|
|
table.slot_gradient(i));
|
|
|
|
|
}
|
|
|
|
|
_param_config.slot_input_vec[table.table_id()] = std::move(tmp_sparse_variable_name);
|
|
|
|
|
_param_config.gradient_var[table.table_id()] = std::move(tmp_sparse_gradient_variable_name);
|
|
|
|
|
_param_config.slot_input_vec[table.table_id()] =
|
|
|
|
|
std::move(tmp_sparse_variable_name);
|
|
|
|
|
_param_config.gradient_var[table.table_id()] =
|
|
|
|
|
std::move(tmp_sparse_gradient_variable_name);
|
|
|
|
|
_param_config.sparse_table_id.push_back(table.table_id());
|
|
|
|
|
}
|
|
|
|
|
//dense
|
|
|
|
|
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) {
|
|
|
|
|
|
|
|
|
|
for (auto t = 0u;
|
|
|
|
|
t < _pslib_ptr->get_param()->trainer_param().dense_table_size();
|
|
|
|
|
++t) {
|
|
|
|
|
auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
|
|
|
|
|
std::vector<std::string> tmp_dense_variable_name;
|
|
|
|
|
for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
|
|
|
|
@ -134,20 +161,18 @@ void AsyncExecutor::InitParamConfig() {
|
|
|
|
|
tmp_dense_gradient_variable_name.push_back(
|
|
|
|
|
table.dense_gradient_variable_name(i));
|
|
|
|
|
}
|
|
|
|
|
_param_config.dense_variable_name[table.table_id()] = std::move(tmp_dense_variable_name);
|
|
|
|
|
_param_config.dense_gradient_variable_name[table.table_id()] = std::move(tmp_dense_gradient_variable_name);
|
|
|
|
|
_param_config.dense_variable_name[table.table_id()] =
|
|
|
|
|
std::move(tmp_dense_variable_name);
|
|
|
|
|
_param_config.dense_gradient_variable_name[table.table_id()] =
|
|
|
|
|
std::move(tmp_dense_gradient_variable_name);
|
|
|
|
|
_param_config.dense_table_id.push_back(table.table_id());
|
|
|
|
|
_param_config.dense_table_size.push_back(table.fea_dim()); //TODO
|
|
|
|
|
_param_config.dense_table_size.push_back(table.fea_dim());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AsyncExecutor::InitModel() {
|
|
|
|
|
//TODO only rank = 0 do this
|
|
|
|
|
//std::vector<int> all_dense_table_id; //TODO
|
|
|
|
|
//all_dense_table_id.push_back(0); //done
|
|
|
|
|
for (auto table_id: _param_config.dense_table_id) {
|
|
|
|
|
for (auto table_id : _param_config.dense_table_id) {
|
|
|
|
|
std::vector<paddle::ps::Region> regions;
|
|
|
|
|
//std::vector<std::string> variables; //TODO
|
|
|
|
|
for (auto& t : _param_config.dense_variable_name[table_id]) {
|
|
|
|
|
Variable* var = root_scope_->FindVar(t);
|
|
|
|
|
CHECK(var != nullptr) << "var[" << t << "] not found";
|
|
|
|
@ -169,13 +194,15 @@ void AsyncExecutor::InitModel() {
|
|
|
|
|
regions.emplace_back(std::move(reg));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto push_status = _pslib_ptr->_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id);
|
|
|
|
|
auto push_status =
|
|
|
|
|
_pslib_ptr->_worker_ptr->push_dense_param(
|
|
|
|
|
regions.data(), regions.size(), table_id);
|
|
|
|
|
push_status.wait();
|
|
|
|
|
auto status = push_status.get();
|
|
|
|
|
if (status != 0) {
|
|
|
|
|
LOG(FATAL) << "push dense param failed, status[" << status << "]";
|
|
|
|
|
exit(-1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -185,7 +212,7 @@ void AsyncExecutor::SaveModel(const std::string& path) {
|
|
|
|
|
ret = _pslib_ptr->_worker_ptr->save(path, 0);
|
|
|
|
|
ret.wait();
|
|
|
|
|
int32_t feasign_cnt = ret.get();
|
|
|
|
|
if (feasign_cnt == -1) { // TODO should be feasign_cnt < 0, because server bug
|
|
|
|
|
if (feasign_cnt == -1) { // (colourful-tree) TODO should be feasign_cnt < 0
|
|
|
|
|
LOG(FATAL) << "save model failed";
|
|
|
|
|
exit(-1);
|
|
|
|
|
}
|
|
|
|
@ -195,13 +222,13 @@ void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
|
|
|
|
|
if (mode == "mpi") {
|
|
|
|
|
DensePullThreadParam param;
|
|
|
|
|
param.ps_client = _pslib_ptr->_worker_ptr;;
|
|
|
|
|
param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO
|
|
|
|
|
param.threshold = 1;
|
|
|
|
|
param.training_thread_num = actual_thread_num;
|
|
|
|
|
param.root_scope = root_scope_;
|
|
|
|
|
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
|
|
|
|
|
param.dense_params = &_param_config.dense_variable_name;
|
|
|
|
|
|
|
|
|
|
_pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
|
|
|
|
|
_pull_dense_thread = std::shared_ptr<DensePullThread>(
|
|
|
|
|
new DensePullThread(param));
|
|
|
|
|
_pull_dense_thread->start();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|