refine pslib interface & fix some bugs

revert-15207-remove_op_handle_lock_and_fix_var
heqiaozhi 6 years ago
parent d3ca359e44
commit 575ae7c6c3

@@ -217,7 +217,7 @@ include(cupti)
include(external/gzstream)
endif (NOT WIN32)
include(external/libmct)
#include(external/pslib_brpc)
include(external/pslib_brpc)
include(external/pslib)
if(WITH_DISTRIBUTE)
@@ -280,7 +280,7 @@ set(EXTERNAL_LIBS
zlib
${PYTHON_LIBRARIES}
pslib
#pslib_brpc
pslib_brpc
libmct
)

@@ -65,18 +65,35 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
readers[0]->SetFileList(filelist);
}
void AsyncExecutor::ConfigPslib(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) {
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
_pslib_ptr->init_and_config(dist_desc, host_sign_list, node_num, index);//TODO done
_pslib_ptr->init_server(dist_desc, index);//TODO done
InitParamConfig();
}
void AsyncExecutor::StartServer() {
void AsyncExecutor::InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
_pslib_ptr->init_worker(dist_desc, host_sign_list.data(), node_num, index);//TODO done
InitParamConfig();
_pslib_ptr->run_server();
}
uint64_t AsyncExecutor::StartServer() {
return _pslib_ptr->run_server();
}
void AsyncExecutor::GatherServers(std::vector<uint64_t>& host_sign_list, int node_num) {
_pslib_ptr->gather_servers(host_sign_list.data(), node_num);
}
void AsyncExecutor::InitParamConfig() {
_param_config.fea_dim = _pslib_ptr->get_param()->trainer_param().sparse_table(0).feature_dim(); //TODO
for (int i = 0; i < _pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param_size(); ++i) {
if (_pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param(i).table_class().find("SparseTable") != -1) {
_param_config.fea_dim = _pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param(i).accessor().fea_dim(); //TODO
break;
}
}
_param_config.slot_dim = _param_config.fea_dim - 2; //TODO
_param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().pull_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
@@ -176,6 +193,7 @@ void AsyncExecutor::PrepareDenseThread() {
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start();
}
@@ -238,6 +256,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
fetch_var_names, root_scope_, thidx, debug);
}
// start executing ops in multiple threads
for (int thidx = 0; thidx < actual_thread_num; ++thidx) {
threads.push_back(

@@ -63,9 +63,11 @@ class AsyncExecutor {
const std::vector<std::string>& fetch_names,
const bool debug = false);
//void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
void ConfigPslib(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index);
void InitServer(const std::string& dist_desc, int index);
void InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index);
//void ConfigWorker() {}
void StartServer();
uint64_t StartServer();
void GatherServers(std::vector<uint64_t>& host_sign_list, int node_num);
void InitModel();
void SaveModel(const std::string& path);
void InitParamConfig();

@@ -345,9 +345,12 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() {
if (op->Type().find("sgd") != std::string::npos) {
continue;
}
if (op->Type().find("lookup_table") != std::string::npos ||
op->Type().find("lookup_table_grad") != std::string::npos) {
continue;
}
op->Run(*thread_scope_, place_);
}
UpdateParams();
}
@@ -416,8 +419,8 @@ void AsyncExecutorThreadWorker::UpdateParams() {
for (auto i: _param_config->dense_table_id) {
PushDense(i);
}
int32_t tmp_push_dense_wait_times = _param_config->tmp_push_dense_wait_times; //TODO
int32_t tmp_push_sparse_wait_times = _param_config->tmp_push_sparse_wait_times; //TODO
int32_t tmp_push_dense_wait_times = -1;//_param_config->tmp_push_dense_wait_times; //TODO
int32_t tmp_push_sparse_wait_times = -1;//_param_config->tmp_push_sparse_wait_times; //TODO
static uint32_t push_dense_wait_times = static_cast<uint32_t>(tmp_push_dense_wait_times);
static uint32_t push_sparse_wait_times = static_cast<uint32_t>(tmp_push_sparse_wait_times);
@@ -430,7 +433,6 @@ void AsyncExecutorThreadWorker::UpdateParams() {
if (tmp_push_dense_wait_times == -1) {
_push_dense_status.resize(0);
}
if (_push_sparse_status.size() >= push_sparse_wait_times) {
for (auto& t : _push_sparse_status) {
t.wait();
@@ -440,7 +442,6 @@ void AsyncExecutorThreadWorker::UpdateParams() {
if (tmp_push_sparse_wait_times == -1) {
_push_sparse_status.resize(0);
}
//for (auto dense_table_id : GlobalConfig::instance().dense_table_id) {//TODO
for (auto dense_table_id: _param_config->dense_table_id) {
_pull_dense_thread->increase_thread_version(thread_id_, dense_table_id);
@@ -451,8 +452,8 @@ void AsyncExecutorThreadWorker::UpdateParams() {
void AsyncExecutorThreadWorker::PushDense(int table_id) {
std::vector<paddle::ps::Region> regions;
//auto& variables = GlobalConfig::instance().dense_gradient_variable_name[table_id];
std::vector<std::string> variables;
for (auto& t : variables) {
//std::vector<std::string> variables;
for (auto& t : _param_config->dense_gradient_variable_name[table_id]) {
Variable* var = thread_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
@@ -469,7 +470,6 @@ void AsyncExecutorThreadWorker::PushDense(int table_id) {
void AsyncExecutorThreadWorker::PullSparse(int table_id) {
auto& features = _features[table_id];
auto& feature_value = _feature_value[table_id];
auto fea_dim = _param_config->fea_dim; //TODO
@@ -477,7 +477,6 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) {
features.clear();
features.resize(0);
features.reserve(MAX_FEASIGN_NUM);
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
// slot_idx = 0 is label TODO
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
@@ -493,14 +492,14 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) {
features.push_back(static_cast<uint64_t>(ids[i]));
}
}
check_pull_push_memory(features, feature_value, fea_dim);
std::vector<float*> pull_feature_value;
for (auto i = 0u; i < features.size(); ++i) {
pull_feature_value.push_back(feature_value[i].data());
}
for (int i = 0; i < features.size(); ++i) {
}
auto status = _pslib_ptr->_worker_ptr->pull_sparse(
pull_feature_value.data(), table_id, features.data(), features.size());
_pull_sparse_status.push_back(std::move(status));
@@ -532,10 +531,15 @@ void AsyncExecutorThreadWorker::FillSparse(int table_id) {
LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(_param_config->slot_input_vec[table_id][slot_idx - 1]);
LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
float* ptr = tensor_emb->data<float>();
float* ptr = tensor_emb->mutable_data<float>({len, slot_dim}, platform::CPUPlace());
memset(ptr, 0, sizeof(float) * len * slot_dim);
auto& tensor_lod = tensor->lod()[0];
LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod);
//float* ptr = tensor_emb->data<float>();
for (auto index = 0u; index < len; ++index){
//if (_current_train_job.use_cvm_feature()) {
@@ -576,7 +580,6 @@ void AsyncExecutorThreadWorker::PushSparse(int table_id) {
//}
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
// slot_idx = 0 is label TODO
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
if (_param_config->slot_alias_to_table[feed_vec[slot_idx]] != table_id) {

@@ -48,8 +48,10 @@ void BindAsyncExecutor(py::module* m) {
new framework::AsyncExecutor(scope, place));
}))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile)
.def("config_pslib", &framework::AsyncExecutor::ConfigPslib)
.def("init_server", &framework::AsyncExecutor::InitServer)
.def("init_worker", &framework::AsyncExecutor::InitWorker)
.def("start_server", &framework::AsyncExecutor::StartServer)
.def("gather_servers", &framework::AsyncExecutor::GatherServers)
.def("init_model", &framework::AsyncExecutor::InitModel)
.def("save_model", &framework::AsyncExecutor::SaveModel);
} // end BindAsyncExecutor

@@ -158,8 +158,17 @@ class AsyncExecutor(object):
return
def init_server(self, filename, index):
self.executor.init_server(filename, index)
def init_worker(self, filename, ips, nodes_cnt, index):
self.executor.init_worker(filename, ips, nodes_cnt, index)
def start_server(self):
self.executor.start_server()
return self.executor.start_server()
def gather_servers(self, ips, nodes_cnt):
self.executor.gather_servers(ips, nodes_cnt)
def init_model(self):
self.executor.init_model()

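Taken together, the bindings and the Python wrapper above suggest the following driver-side flow. This is a minimal sketch, not part of the commit: the names role, dist_desc, host_sign_list and my_index are hypothetical, the AsyncExecutor construction follows the existing wrapper, and the ordering of gather_servers relative to init_worker is an assumption.

import paddle.fluid as fluid

exe = fluid.AsyncExecutor()  # assumed construction; scope/place handling is unchanged by this commit
if role == "server":
    exe.init_server(dist_desc, my_index)      # parses the pslib config and runs InitParamConfig
    server_sign = exe.start_server()          # start_server now returns the server's sign (uint64)
    # server_sign would be exchanged out-of-band and collected into host_sign_list on every node
else:  # worker
    exe.init_worker(dist_desc, host_sign_list, len(host_sign_list), my_index)
    exe.gather_servers(host_sign_list, len(host_sign_list))
    exe.init_model()                          # typically only one worker initializes the dense params
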
@@ -56,7 +56,7 @@ class DownpourSGD(object):
params_grads[0], params_grads[1])
ps_param = pslib.PSParameter()
ps_param.server_param.CopyFrom(server.get_desc())
ps_param.worker_param.CopyFrom(worker.get_desc())
ps_param.trainer_param.CopyFrom(worker.get_desc())
# Todo(guru4elephant): figure out how to support more sparse parameters
# currently only support lookup_table
worker_skipped_ops = ["lookup_table", "lookup_table_grad"]

File diff suppressed because one or more lines are too long