|
|
|
@ -86,7 +86,7 @@ class AsyncExecutor(object):
|
|
|
|
|
|
|
|
|
|
scope = global_scope()
|
|
|
|
|
self.executor = core.AsyncExecutor(scope, p)
|
|
|
|
|
self.instance = ps_instance.PaddlePSInstance("init_param", 1, 2)
|
|
|
|
|
self.instance = ps_instance.PaddlePSInstance(1, 2)
|
|
|
|
|
|
|
|
|
|
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
|
|
|
|
|
"""
|
|
|
|
@ -151,10 +151,7 @@ class AsyncExecutor(object):
|
|
|
|
|
self.executor.run_from_files(program_desc,
|
|
|
|
|
data_feed.desc(), filelist, thread_num,
|
|
|
|
|
fetch_var_names, debug)
|
|
|
|
|
self.instance.barrier_all() #worker do all things
|
|
|
|
|
if self.instance.is_first_worker():
|
|
|
|
|
self.executor.stop_server()
|
|
|
|
|
self.instance.barrier_all() #sync
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def config_distributed_nodes(self, dist_opt):
|
|
|
|
|
|
|
|
|
@ -167,8 +164,11 @@ class AsyncExecutor(object):
|
|
|
|
|
def get_instance(self):
|
|
|
|
|
return self.instance
|
|
|
|
|
|
|
|
|
|
#def stop_server(self):
|
|
|
|
|
# self.executor.stop_server()
|
|
|
|
|
def stop_server(self):
|
|
|
|
|
self.instance.barrier_all() #worker do all things
|
|
|
|
|
if self.instance.is_first_worker():
|
|
|
|
|
self.executor.stop_server()
|
|
|
|
|
self.instance.barrier_all() #sync
|
|
|
|
|
|
|
|
|
|
def init_server(self, dist_desc):
|
|
|
|
|
self.executor.init_server(dist_desc, self.instance._rankid)
|
|
|
|
|