download & run & instance

revert-15207-remove_op_handle_lock_and_fix_var
heqiaozhi 7 years ago
parent 57ac412b98
commit 10ed9e0a6e

@ -191,18 +191,19 @@ void AsyncExecutor::SaveModel(const std::string& path) {
} }
} }
void AsyncExecutor::PrepareDenseThread() { void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
DensePullThreadParam param; if (mode == "mpi") {
param.ps_client = _pslib_ptr->_worker_ptr;; DensePullThreadParam param;
param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO param.ps_client = _pslib_ptr->_worker_ptr;;
param.training_thread_num = actual_thread_num; param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO
param.root_scope = root_scope_; param.training_thread_num = actual_thread_num;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO param.root_scope = root_scope_;
param.dense_params = &_param_config.dense_variable_name; //param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start(); _pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start();
}
} }
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
@ -210,6 +211,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_var_names, const std::vector<std::string>& fetch_var_names,
const std::string& mode,
const bool debug) { const bool debug) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
@ -251,11 +253,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
// todo: should be factory method for creating datafeed // todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers; std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist); PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
PrepareDenseThread(); PrepareDenseThread(mode);
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers; std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num); workers.resize(actual_thread_num);
for (auto& worker : workers) { for (auto& worker : workers) {
worker.reset(new AsyncExecutorThreadWorker); if (mode == "mpi") {
worker.reset(new AsyncExecutorThreadWorker);
} else {
worker.reset(new ExecutorThreadWorker);
}
} }
// prepare thread resource here // prepare thread resource here
@ -274,7 +280,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) { for (auto& th : threads) {
th.join(); th.join();
} }
_pull_dense_thread->stop(); if (mode == "mpi") {
_pull_dense_thread->stop();
}
root_scope_->DropKids(); root_scope_->DropKids();
return; return;

@ -61,6 +61,7 @@ class AsyncExecutor {
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_names,
const std::string& mode,
const bool debug = false); const bool debug = false);
//void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index); //void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
void InitServer(const std::string& dist_desc, int index); void InitServer(const std::string& dist_desc, int index);
@ -79,7 +80,7 @@ class AsyncExecutor {
const std::vector<std::string>& fetch_var_names, const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index, Scope* root_scope, const int thread_index,
const bool debug); const bool debug);
void PrepareDenseThread(); void PrepareDenseThread(const std::string& mode);
public: public:
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr; std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread; std::shared_ptr<DensePullThread> _pull_dense_thread;

@ -87,9 +87,8 @@ class AsyncExecutor(object):
scope = global_scope() scope = global_scope()
self.executor = core.AsyncExecutor(scope, p) self.executor = core.AsyncExecutor(scope, p)
self.instance = ps_instance.PaddlePSInstance(1, 2)
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False): def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False):
""" """
Run program by this AsyncExecutor. Training dataset will be in filelist. Run program by this AsyncExecutor. Training dataset will be in filelist.
Users can also inspect certain variables by naming them in parameter Users can also inspect certain variables by naming them in parameter
@ -151,10 +150,11 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc, self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, debug) fetch_var_names, mode, debug)
def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12): def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12):
hadoop_home = "$HADOOP_HOME" #hadoop_home = "$HADOOP_HOME"
hadoop_home = "~/tools/hadoop-xingtian/hadoop/"
configs = { configs = {
"fs.default.name": fs_default_name, "fs.default.name": fs_default_name,
@ -169,8 +169,11 @@ class AsyncExecutor(object):
self.instance.get_worker_index(), self.instance.get_worker_index(),
self.instance.get_node_cnt() / 2, self.instance.get_node_cnt() / 2,
multi_processes=process_num) multi_processes=process_num)
self.instance.barrier_all() #wait for download_data #TODO only barriere worker
def config_distributed_nodes(self, dist_opt): def config_distributed_nodes(self):
self.instance = ps_instance.PaddlePSInstance(1, 2)
return self.instance
# get total rank # get total rank
# get rank index # get rank index
@ -196,11 +199,15 @@ class AsyncExecutor(object):
self.executor.gather_servers(ips, self.instance.get_node_cnt()) self.executor.gather_servers(ips, self.instance.get_node_cnt())
self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait all worker start
self.instance.barrier_all() #wait init model self.instance.barrier_all() #wait init model
self.instance.barrier_all() #wait for download_data self.instance.barrier_all() #wait for download_data #TODO remove this after only barrier worker
self.instance.barrier_all() #wait worker do all things self.instance.barrier_all() #wait worker do all things
self.instance.barrier_all() #sync self.instance.barrier_all() #sync
def init_worker(self, dist_desc, afs_path, local_path, fs_default_name, ugi): def init_worker(self, dist_desc, startup_program):
place = core.CPUPlace()
executor = Executor(place)
executor.run(startup_program)
self.instance.barrier_all() #wait all server start self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips() ips = self.instance.gather_ips()
self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid) self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
@ -208,8 +215,6 @@ class AsyncExecutor(object):
if self.instance.is_first_worker(): if self.instance.is_first_worker():
self.executor.init_model() self.executor.init_model()
self.instance.barrier_all() #wait init model self.instance.barrier_all() #wait init model
self.download_data(afs_path, local_path, fs_default_name, ugi, process_num=12)
self.instance.barrier_all() #wait for download_data
def init_model(self): def init_model(self):
self.executor.init_model() self.executor.init_model()

Loading…
Cancel
Save